diff --git a/router/src/validation.rs b/router/src/validation.rs index bda19224..b3d4dd9a 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -424,34 +424,36 @@ impl Validation { None => None, }; - let logit_bias = match &request.parameters.logit_bias { - Some(bias) if !bias.is_empty() => { - for (token_str, _) in bias.iter() { - let token_id = token_str.parse::().map_err(|_| { - ValidationError::LogitBiasInvalid(format!( - "Token ID {} is not a valid number.", - token_str - )) - })?; + // Validate logit bias and convert to a vector of (token_id, bias_value) + let logit_bias = request + .parameters + .logit_bias + .as_ref() + .filter(|bias_map| !bias_map.is_empty()) + .map(|bias_map| { + bias_map + .iter() + .map(|(token_str, &bias_value)| { + let token_id: u32 = token_str.parse().map_err(|_| { + ValidationError::LogitBiasInvalid(format!( + "Token ID {token_str} is not a valid number." + )) + })?; - if token_id >= self.vocab_size { - return Err(ValidationError::LogitBiasInvalid(format!( - "Token ID {} is out of range. Must be between 0 and {}.", - token_id, - self.vocab_size - 1 - ))); - } - } + if token_id >= self.vocab_size { + return Err(ValidationError::LogitBiasInvalid(format!( + "Token ID {token_id} is out of range (0..{}).", + self.vocab_size - 1 + ))); + } - // Transform into the required format - Some( - bias.iter() - .map(|(k, v)| (k.parse::().unwrap(), *v as f32)) - .collect(), - ) - } - _ => None, - }; + Ok((token_id, bias_value as f32)) + }) + .collect::, _>>() + }) + // convert Option> to Result, E> to throw + // if any of the token IDs are invalid + .transpose()?; let parameters = ValidParameters { temperature,