mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-06-09 02:42:05 +00:00

commit 9eeccbf9a5 ("fix: improve processor logic and refactor")
parent bb5c875f0b
@@ -1192,7 +1192,6 @@ pub(crate) async fn chat_completions(
     let (generate_request, using_tools): (GenerateRequest, bool) =
         chat.clone().try_into_generate(&infer)?;
     span.record("parameters", format!("{:?}", generate_request.parameters));
-    println!("ChatRequest: {:#?}", generate_request);
     let logprobs = logprobs.unwrap_or_default();
 
     // extract model id from request if specified
@@ -34,6 +34,7 @@ pub struct Validation {
     max_input_length: usize,
     max_total_tokens: usize,
     disable_grammar_support: bool,
+    vocab_size: u32,
     /// Channel to communicate with the background tokenization task
     sender: mpsc::UnboundedSender<TokenizerRequest>,
 }
@@ -88,6 +89,19 @@ impl Validation {
             validation_sender
         };
 
+        let vocab_size = match &tokenizer {
+            Tokenizer::Python { tokenizer_name, .. } => {
+                warn!(
+                    "Tokenizer {} is not supported for validation",
+                    tokenizer_name
+                );
+                0
+            }
+            Tokenizer::Rust(tokenizer) => tokenizer.get_vocab_size(false),
+        }
+        .try_into()
+        .unwrap_or(0);
+
         Self {
             max_best_of,
             sender,
@@ -96,6 +110,7 @@ impl Validation {
             max_input_length,
             max_total_tokens,
             disable_grammar_support,
+            vocab_size,
         }
     }
 
@@ -409,6 +424,35 @@ impl Validation {
             None => None,
         };
 
+        let logit_bias = match &request.parameters.logit_bias {
+            Some(bias) if !bias.is_empty() => {
+                for (token_str, _) in bias.iter() {
+                    let token_id = token_str.parse::<u32>().map_err(|_| {
+                        ValidationError::LogitBiasInvalid(format!(
+                            "Token ID {} is not a valid number.",
+                            token_str
+                        ))
+                    })?;
+
+                    if token_id >= self.vocab_size {
+                        return Err(ValidationError::LogitBiasInvalid(format!(
+                            "Token ID {} is out of range. Must be between 0 and {}.",
+                            token_id,
+                            self.vocab_size - 1
+                        )));
+                    }
+                }
+
+                // Transform into the required format
+                Some(
+                    bias.iter()
+                        .map(|(k, v)| (k.parse::<u32>().unwrap(), *v as f32))
+                        .collect(),
+                )
+            }
+            _ => None,
+        };
+
         let parameters = ValidParameters {
             temperature,
             repetition_penalty,
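For intuition, here is a minimal Python sketch of the check the router now performs on `logit_bias` before a request is accepted. The function name and error strings are illustrative assumptions; only the rule itself (keys must be numeric token IDs below the vocabulary size) comes from the hunk above.

from typing import Dict, Optional


def validate_logit_bias(
    bias: Optional[Dict[str, float]], vocab_size: int
) -> Optional[Dict[int, float]]:
    """Hypothetical mirror of the Rust check: numeric keys, IDs within the vocab."""
    if not bias:
        return None
    validated = {}
    for token_str, value in bias.items():
        if not token_str.isdigit():
            raise ValueError(f"Token ID {token_str} is not a valid number.")
        token_id = int(token_str)
        if token_id >= vocab_size:
            raise ValueError(
                f"Token ID {token_id} is out of range. "
                f"Must be between 0 and {vocab_size - 1}."
            )
        validated[token_id] = float(value)
    return validated


# validate_logit_bias({"9707": -100}, vocab_size=32000) -> {9707: -100.0}
# validate_logit_bias({"abc": 5}, vocab_size=32000) raises ValueError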
@@ -420,18 +464,7 @@ impl Validation {
             seed,
             watermark,
             grammar,
-            logit_bias: Some(
-                request
-                    .parameters
-                    .logit_bias
-                    .iter()
-                    .flat_map(|bias| {
-                        bias.iter()
-                            .map(|(k, v)| (k.parse::<u32>().unwrap(), *v as f32))
-                            .collect::<Vec<_>>()
-                    })
-                    .collect(),
-            ),
+            logit_bias,
         };
         let stopping_parameters = ValidStoppingParameters {
             max_new_tokens,
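As a usage sketch: assuming the server's OpenAI-compatible chat endpoint accepts a `logit_bias` map of token-ID strings to bias values (the URL, port, and model name below are placeholders, not taken from this diff), a biased request looks like the following, and a non-numeric or out-of-range key should now come back as a validation error instead of reaching the model.

import requests

resp = requests.post(
    "http://localhost:3000/v1/chat/completions",  # hypothetical local TGI instance
    json={
        "model": "tgi",
        "messages": [{"role": "user", "content": "Say hello"}],
        "max_tokens": 16,
        # Keys are token IDs as strings; negative values suppress a token,
        # positive values encourage it.
        "logit_bias": {"9707": -100},
    },
    timeout=60,
)
print(resp.status_code, resp.json())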
@@ -1011,6 +1044,8 @@ pub enum ValidationError {
     FailedFetchImage(#[from] reqwest::Error),
     #[error("{0} modality is not supported")]
     UnsupportedModality(&'static str),
+    #[error("logit_bias is not valid: {0}")]
+    LogitBiasInvalid(String),
 }
 
 #[cfg(test)]
@@ -625,55 +625,49 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
         return self
 
 
-class LogitBiasProcessor:
-    """Process logits with logit biases."""
+class LogitBiasProcessor(LogitsProcessor):
+    """
+    `LogitsProcessor` that creates a bias tensor from a dictionary of token IDs and their
+    corresponding bias values. Biases are applied to the logits during each forward pass.
+
+    Supports token IDs provided as strings (e.g., {"9707": -100}).
+    """
 
     def __init__(
-        self, logit_biases: Optional[dict], tokenizer: PreTrainedTokenizerBase
+        self,
+        logit_biases: dict,
+        tokenizer: PreTrainedTokenizerBase,
+        device: torch.device,
     ):
-        self.tokenizer = tokenizer
-        self.logit_biases = logit_biases or {}
-
-        # Pre-compute token IDs for each token string
-        self.token_id_mapping = {}
+        assert logit_biases, "LogitBiasProcessor requires non-empty logit_biases"
+
+        vocab_size = len(tokenizer)
+
+        # Convert keys to integers and values to a list
+        token_ids = torch.tensor(
+            [int(k) for k in logit_biases.keys()], dtype=torch.long
+        )
+        bias_values = torch.tensor(list(logit_biases.values()), dtype=torch.float)
+
+        # Create a tensor and directly copy bias values at the corresponding indices
+        self.bias_tensor = torch.zeros(vocab_size, dtype=torch.float)
+        self.bias_tensor.index_put_((token_ids,), bias_values, accumulate=True)
 
     def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
-        # If no logit biases, return scores unchanged
-        if not self.logit_biases:
-            return scores
-
-        # Apply bias to the corresponding scores
-        for token_str, bias_value in self.logit_biases.items():
-            # Get token ID, either from cache or by computing it
-            if token_str not in self.token_id_mapping:
-                if token_str.isdigit():
-                    # If the token string is already a numeric ID
-                    token_id = int(token_str)
-                else:
-                    # Otherwise, use the tokenizer to get the ID
-                    tokens = self.tokenizer.encode(token_str, add_special_tokens=False)
-                    token_id = tokens[0] if tokens else -1  # Use -1 for not found
-
-                self.token_id_mapping[token_str] = token_id
-
-            token_id = self.token_id_mapping[token_str]
-
-            # Apply bias if token ID is valid
-            if 0 <= token_id < scores.size(-1):
-                scores[:, token_id] += bias_value
-
+        # Apply bias tensor as a broadcasted addition
+        if self.bias_tensor.shape[0] != scores.shape[1]:
+            # Fix if the bias tensor is smaller than the scores
+            self.bias_tensor = torch.nn.functional.pad(
+                self.bias_tensor, (0, scores.shape[1] - self.bias_tensor.shape[0])
+            )
+        scores.add_(self.bias_tensor.to(device=scores.device, dtype=scores.dtype))
         return scores
 
-    def filter(self, indices):
-        """Keep only the logit biases for the specified indices."""
-        new_logit_biases = {
-            k: self.logit_biases[k] for k in indices if k in self.logit_biases
-        }
-        return LogitBiasProcessor(new_logit_biases, self.tokenizer)
-
 
-class HeterogeneousLogitBiasProcessor:
-    """Process logits with different logit biases for each sequence in the batch."""
+class HeterogeneousLogitBiasProcessor(LogitsProcessor):
+    """
+    Process logits with different logit biases for each sequence in the batch.
+    """
 
     def __init__(
         self,
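To see what the rewritten processor is doing, here is a self-contained torch sketch of the same mechanics: build a dense bias vector once, pad it if the score matrix has extra columns, and broadcast-add it on every step. The sizes and bias values are made up and no tokenizer is involved.

import torch

vocab_size = 8
logit_biases = {"3": -100.0, "5": 2.5}  # token-ID strings -> bias values

# Build the dense bias vector once (what __init__ now does).
token_ids = torch.tensor([int(k) for k in logit_biases], dtype=torch.long)
bias_values = torch.tensor(list(logit_biases.values()), dtype=torch.float)
bias_tensor = torch.zeros(vocab_size, dtype=torch.float)
bias_tensor.index_put_((token_ids,), bias_values, accumulate=True)

# Each step, broadcast-add the vector onto a (batch, padded_vocab) score matrix.
scores = torch.zeros(2, 10)  # e.g. logits padded beyond the tokenizer vocab
if bias_tensor.shape[0] != scores.shape[1]:
    bias_tensor = torch.nn.functional.pad(
        bias_tensor, (0, scores.shape[1] - bias_tensor.shape[0])
    )
scores.add_(bias_tensor)
print(scores)  # column 3 is -100 and column 5 is 2.5 in every row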
@@ -681,54 +675,42 @@ class HeterogeneousLogitBiasProcessor:
         tokenizer: PreTrainedTokenizerBase,
         device: torch.device,
     ):
-        self.device = device
         self.tokenizer = tokenizer
         self.logit_biases = logit_biases
-        self.batch_size = len(logit_biases)
+        # import ipdb; ipdb.set_trace()
+        self.vocab_size = len(tokenizer)
 
-        # Pre-compute token IDs for each token string
-        self.token_id_mapping = {}
+        # Create batch_size x vocab_size bias matrix
+        self.bias_matrix = torch.zeros(
+            (len(logit_biases), self.vocab_size), dtype=torch.float, device=device
+        )
 
-        # Create a mapping of indices that have logit biases
-        self.indices_with_biases = {
-            i: bias_dict
-            for i, bias_dict in enumerate(self.logit_biases)
-            if bias_dict is not None and len(bias_dict) > 0
-        }
+        # for each logit bias dictionary, convert keys to integers and values to a list
+        for i, logit_bias in enumerate(logit_biases):
+            token_ids = torch.tensor(
+                [int(k) for k in logit_bias.keys()], dtype=torch.long
+            ).to(device=device)
+            bias_values = torch.tensor(list(logit_bias.values()), dtype=torch.float).to(
+                device=device
+            )
+            # Create a tensor and directly copy bias values at the corresponding indices
+            self.bias_matrix[i].index_put_((token_ids,), bias_values, accumulate=True)
 
     def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
-        # If no indices have biases, return scores unchanged
-        if not self.indices_with_biases:
-            return scores
-
-        # For each index with a bias, apply the bias to the corresponding scores
-        for i, bias_dict in self.indices_with_biases.items():
-            for token_str, bias_value in bias_dict.items():
-                # Get token ID, either from cache or by computing it
-                if token_str not in self.token_id_mapping:
-                    if token_str.isdigit():
-                        # If the token string is already a numeric ID
-                        token_id = int(token_str)
-                    else:
-                        # Otherwise, use the tokenizer to get the ID
-                        tokens = self.tokenizer.encode(
-                            token_str, add_special_tokens=False
-                        )
-                        token_id = tokens[0] if tokens else -1  # Use -1 for not found
-
-                    self.token_id_mapping[token_str] = token_id
-
-                token_id = self.token_id_mapping[token_str]
-
-                # Apply bias if token ID is valid
-                if 0 <= token_id < scores.size(-1):
-                    scores[i, token_id] += bias_value
-
+        # Apply bias matrix as a broadcasted addition
+        if self.bias_matrix.shape[1] != scores.shape[1]:
+            # Fix if the bias matrix is smaller than the scores
+            self.bias_matrix = torch.nn.functional.pad(
+                self.bias_matrix, (0, scores.shape[1] - self.bias_matrix.shape[1])
+            )
+        scores.add_(self.bias_matrix.to(device=scores.device, dtype=scores.dtype))
         return scores
 
-    def filter(self, indices: List[int]):
-        """Keep only the logit biases for the specified indices."""
+    def filter(self, indices):
         new_logit_biases = [self.logit_biases[i] for i in indices]
+        if not any(bias and len(bias) > 0 for bias in new_logit_biases):
+            return None
         return HeterogeneousLogitBiasProcessor(
             new_logit_biases, self.tokenizer, self.device
         )
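The heterogeneous variant applies the same idea per sequence: one bias row per request, all-zero rows for requests without a bias, and a single broadcasted add per step. A small sketch with made-up sizes:

import torch

vocab_size = 6
per_sequence_biases = [{"2": 5.0}, {}, {"0": -1.0, "4": 3.0}]  # one dict per sequence

bias_matrix = torch.zeros(len(per_sequence_biases), vocab_size)
for i, logit_bias in enumerate(per_sequence_biases):
    if not logit_bias:
        continue  # row i stays all-zero, so that sequence's logits are untouched
    token_ids = torch.tensor([int(k) for k in logit_bias], dtype=torch.long)
    bias_values = torch.tensor(list(logit_bias.values()), dtype=torch.float)
    bias_matrix[i].index_put_((token_ids,), bias_values, accumulate=True)

scores = torch.zeros(3, vocab_size)
scores.add_(bias_matrix)  # element-wise add: each row only receives its own biases
print(scores)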
@@ -66,7 +66,6 @@ class NextTokenChooser:
             else None
         )
         self.tokenizer = tokenizer
-        self.logit_bias = logit_bias
 
         has_warpers = (
             (temperature is not None and temperature != 1.0)
@@ -136,7 +135,7 @@ class NextTokenChooser:
             tokenizer=tokenizer,
             grammar=pb.grammar,
             grammar_type=pb.grammar_type,
-            logit_bias=dict(pb.logit_bias) if pb.logit_bias else None,
+            logit_bias=pb.logit_bias,
         )
 
 
@@ -264,10 +263,6 @@ class HeterogeneousNextTokenChooser:
     ):
         warpers = []
 
-        # Initialize with empty logit biases if none provided
-        if logit_biases is None:
-            logit_biases = [None] * len(do_sample)
-
         self.watermark_processor = (
             HeterogeneousProcessorWrapper(
                 {
@@ -306,7 +301,7 @@ class HeterogeneousNextTokenChooser:
 
         self.logit_bias_processor = (
             HeterogeneousLogitBiasProcessor(logit_biases, tokenizer, device)
-            if any([bias is not None and len(bias) > 0 for bias in logit_biases])
+            if any([logit_bias is not None for logit_bias in logit_biases])
             else None
         )
 
@@ -530,9 +525,7 @@ class HeterogeneousNextTokenChooser:
             fsm_grammar_states=(
                 fsm_grammar_states if fsm_grammar_states else [0] * len(pb)
             ),
-            logit_biases=[
-                dict(pb_.logit_bias) if pb_.logit_bias else None for pb_ in pb
-            ],
+            logit_biases=[pb_.logit_bias for pb_ in pb],
         )
 
 
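Taken together, the chooser hunks pass each request's `logit_bias` field straight through, build the heterogeneous processor only when at least one sequence actually carries a bias, and let `filter` return `None` once no surviving sequence does. A rough standalone sketch of that gating, with plain dicts standing in for the proto field:

# One logit_bias entry per request in the batch; only the first request set one.
logit_biases = [{"42": 10.0}, None, None]

# Construction-time gate (mirrors HeterogeneousNextTokenChooser.__init__):
build_processor = any(logit_bias is not None for logit_bias in logit_biases)
print(build_processor)  # True -> HeterogeneousLogitBiasProcessor is created

# Filter-time gate (mirrors HeterogeneousLogitBiasProcessor.filter): after the
# batch shrinks to requests 1 and 2, no bias is left and the processor is dropped.
kept = [logit_biases[i] for i in (1, 2)]
print(any(bias and len(bias) > 0 for bias in kept))  # False -> filter returns None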