fix: improve validation and transform logic

This commit is contained in:
drbh 2025-05-05 13:59:02 -04:00
parent 465294d3de
commit 7659925d85

View File

@ -424,34 +424,36 @@ impl Validation {
None => None,
};
let logit_bias = match &request.parameters.logit_bias {
Some(bias) if !bias.is_empty() => {
for (token_str, _) in bias.iter() {
let token_id = token_str.parse::<u32>().map_err(|_| {
ValidationError::LogitBiasInvalid(format!(
"Token ID {} is not a valid number.",
token_str
))
})?;
// Validate logit bias and convert to a vector of (token_id, bias_value)
let logit_bias = request
.parameters
.logit_bias
.as_ref()
.filter(|bias_map| !bias_map.is_empty())
.map(|bias_map| {
bias_map
.iter()
.map(|(token_str, &bias_value)| {
let token_id: u32 = token_str.parse().map_err(|_| {
ValidationError::LogitBiasInvalid(format!(
"Token ID {token_str} is not a valid number."
))
})?;
if token_id >= self.vocab_size {
return Err(ValidationError::LogitBiasInvalid(format!(
"Token ID {} is out of range. Must be between 0 and {}.",
token_id,
self.vocab_size - 1
)));
}
}
if token_id >= self.vocab_size {
return Err(ValidationError::LogitBiasInvalid(format!(
"Token ID {token_id} is out of range (0..{}).",
self.vocab_size - 1
)));
}
// Transform into the required format
Some(
bias.iter()
.map(|(k, v)| (k.parse::<u32>().unwrap(), *v as f32))
.collect(),
)
}
_ => None,
};
Ok((token_id, bias_value as f32))
})
.collect::<Result<Vec<_>, _>>()
})
// convert Option<Result<T, E>> to Result<Option<T>, E> to throw
// if any of the token IDs are invalid
.transpose()?;
let parameters = ValidParameters {
temperature,