Making things work most of the time.

This commit is contained in:
Nicolas Patry 2024-04-11 18:30:38 +00:00
parent 9ce9f39dea
commit d43e10e097
2 changed files with 14 additions and 7 deletions


@@ -1290,7 +1290,12 @@ fn main() -> Result<(), LauncherError> {
     let content = std::fs::read_to_string(filename)?;
     let config: Config = serde_json::from_str(&content)?;
-    let max_default = 2usize.pow(14);
+    // Quantization usually means you're even more RAM constrained.
+    let max_default = if args.quantize.is_some() {
+        4096
+    } else {
+        2usize.pow(14)
+    };
     let max_position_embeddings = if config.max_position_embeddings > max_default {
         let max = config.max_position_embeddings;
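As a rough illustration of what this first hunk does, here is a minimal, self-contained sketch of the capping logic. Config, the quantize flag, and default_max_tokens are simplified stand-ins rather than the launcher's real types and function names, and the hunk itself only shows the comparison, not what happens inside the branch, so treat this as one plausible reading.

// Sketch only: simplified stand-ins for the launcher's real types; the actual
// code lives inline in fn main().
struct Config {
    max_position_embeddings: usize,
}

fn default_max_tokens(quantize: Option<&str>, config: &Config) -> usize {
    // Quantized models are assumed to be more RAM constrained, so the fallback
    // ceiling drops from 2^14 (16384) tokens to 4096.
    let max_default = if quantize.is_some() {
        4096
    } else {
        2usize.pow(14)
    };
    // Cap the model's advertised context length at the chosen default.
    if config.max_position_embeddings > max_default {
        max_default
    } else {
        config.max_position_embeddings
    }
}

fn main() {
    let config = Config { max_position_embeddings: 32_768 };
    assert_eq!(default_max_tokens(Some("bitsandbytes"), &config), 4096);
    assert_eq!(default_max_tokens(None, &config), 16_384);
}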


@@ -163,13 +163,15 @@ impl Validation {
         };
         let input_length = truncate.unwrap_or(self.max_input_length);
+        // We don't have a tokenizer, therefore we have no idea how long is the query, let
+        // them through and hope for the best.
         // Validate MaxNewTokens
-        if (input_length as u32 + max_new_tokens) > self.max_total_tokens as u32 {
-            return Err(ValidationError::MaxNewTokens(
-                self.max_total_tokens - self.max_input_length,
-                max_new_tokens,
-            ));
-        }
+        // if (input_length as u32 + max_new_tokens) > self.max_total_tokens as u32 {
+        //     return Err(ValidationError::MaxNewTokens(
+        //         self.max_total_tokens - self.max_input_length,
+        //         max_new_tokens,
+        //     ));
+        // }
         Ok((inputs, input_length, max_new_tokens))
     }
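
To make the second hunk's effect concrete, here is a minimal before/after sketch. The Validation struct, error variant, and method names below are trimmed stand-ins for the router's real types, kept only to show what the commented-out check used to reject and what now passes through.

// Sketch only: trimmed stand-ins for the router's Validation type.
struct Validation {
    max_input_length: usize,
    max_total_tokens: usize,
}

#[derive(Debug)]
enum ValidationError {
    MaxNewTokens(usize, u32),
}

impl Validation {
    // Behavior before this commit: reject when the estimated input length plus
    // the requested new tokens would exceed the total token budget.
    fn check_before(&self, input_length: usize, max_new_tokens: u32) -> Result<(), ValidationError> {
        if (input_length as u32 + max_new_tokens) > self.max_total_tokens as u32 {
            return Err(ValidationError::MaxNewTokens(
                self.max_total_tokens - self.max_input_length,
                max_new_tokens,
            ));
        }
        Ok(())
    }

    // Behavior after this commit: without a tokenizer the input length is only
    // a guess, so the request is let through and the limit is enforced later.
    fn check_after(&self, _input_length: usize, _max_new_tokens: u32) -> Result<(), ValidationError> {
        Ok(())
    }
}

fn main() {
    let v = Validation { max_input_length: 1024, max_total_tokens: 2048 };
    assert!(v.check_before(1024, 2048).is_err());
    assert!(v.check_after(1024, 2048).is_ok());
}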