mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-07-27 02:10:17 +00:00
fix: improve validation and transform logic
This commit is contained in:
parent
465294d3de
commit
7659925d85
@ -424,34 +424,36 @@ impl Validation {
|
||||
None => None,
|
||||
};
|
||||
|
||||
let logit_bias = match &request.parameters.logit_bias {
|
||||
Some(bias) if !bias.is_empty() => {
|
||||
for (token_str, _) in bias.iter() {
|
||||
let token_id = token_str.parse::<u32>().map_err(|_| {
|
||||
ValidationError::LogitBiasInvalid(format!(
|
||||
"Token ID {} is not a valid number.",
|
||||
token_str
|
||||
))
|
||||
})?;
|
||||
// Validate logit bias and convert to a vector of (token_id, bias_value)
|
||||
let logit_bias = request
|
||||
.parameters
|
||||
.logit_bias
|
||||
.as_ref()
|
||||
.filter(|bias_map| !bias_map.is_empty())
|
||||
.map(|bias_map| {
|
||||
bias_map
|
||||
.iter()
|
||||
.map(|(token_str, &bias_value)| {
|
||||
let token_id: u32 = token_str.parse().map_err(|_| {
|
||||
ValidationError::LogitBiasInvalid(format!(
|
||||
"Token ID {token_str} is not a valid number."
|
||||
))
|
||||
})?;
|
||||
|
||||
if token_id >= self.vocab_size {
|
||||
return Err(ValidationError::LogitBiasInvalid(format!(
|
||||
"Token ID {} is out of range. Must be between 0 and {}.",
|
||||
token_id,
|
||||
self.vocab_size - 1
|
||||
)));
|
||||
}
|
||||
}
|
||||
if token_id >= self.vocab_size {
|
||||
return Err(ValidationError::LogitBiasInvalid(format!(
|
||||
"Token ID {token_id} is out of range (0..{}).",
|
||||
self.vocab_size - 1
|
||||
)));
|
||||
}
|
||||
|
||||
// Transform into the required format
|
||||
Some(
|
||||
bias.iter()
|
||||
.map(|(k, v)| (k.parse::<u32>().unwrap(), *v as f32))
|
||||
.collect(),
|
||||
)
|
||||
}
|
||||
_ => None,
|
||||
};
|
||||
Ok((token_id, bias_value as f32))
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
})
|
||||
// convert Option<Result<T, E>> to Result<Option<T>, E> to throw
|
||||
// if any of the token IDs are invalid
|
||||
.transpose()?;
|
||||
|
||||
let parameters = ValidParameters {
|
||||
temperature,
|
||||
|
Loading…
Reference in New Issue
Block a user