diff --git a/launcher/src/main.rs b/launcher/src/main.rs index a4bb2e3d..3f1b37d6 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -22,7 +22,8 @@ mod env_runtime; #[derive(Deserialize)] struct Config { - max_position_embeddings: usize, + max_position_embeddings: Option, + max_seq_len: Option, } #[derive(Clone, Copy, Debug, ValueEnum)] @@ -817,6 +818,14 @@ enum LauncherError { WebserverCannotStart, } +impl core::fmt::Display for LauncherError { + fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result { + write!(f, "{self:?}") + } +} + +impl std::error::Error for LauncherError {} + fn download_convert_model(args: &Args, running: Arc) -> Result<(), LauncherError> { // Enter download tracing span let _span = tracing::span!(tracing::Level::INFO, "download").entered(); @@ -1291,12 +1300,21 @@ fn main() -> Result<(), LauncherError> { 2usize.pow(14) }; - let max_position_embeddings = if config.max_position_embeddings > max_default { - let max = config.max_position_embeddings; - tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max - 1, max - 1); - max_default - } else { - config.max_position_embeddings + let max_position_embeddings = match (config.max_position_embeddings, config.max_seq_len) { + (Some(max_position_embeddings), _) | (None, Some(max_position_embeddings)) => { + if max_position_embeddings > max_default { + let max = max_position_embeddings; + tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1); + max_default + } else { + max_position_embeddings + } + } + _ => { + return Err(Box::new(LauncherError::ArgumentValidation( + "no max defined".to_string(), + ))); + } }; Ok(max_position_embeddings) };