diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index f2625112..1eb1d83d 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -68,16 +68,16 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) -> (String, String) {
     let mut prefix_caching: Option<String> = std::env::var("USE_PREFIX_CACHING").ok();
     let mut attention: Option<String> = std::env::var("ATTENTION").ok();
     if let Some(config) = config {
+        if config.vision_config.is_some() && prefix_caching.is_none() {
+            tracing::info!("Disabling prefix caching because of VLM model");
+            prefix_caching = Some("0".to_string());
+        }
         match config.head_dim {
             Some(h) if h == 64 || h == 128 || h == 256 => {
                 if lora_adapters.is_some() && prefix_caching.is_none() {
                     tracing::info!("Disabling prefix caching because of lora adapters");
                     prefix_caching = Some("0".to_string());
                 }
-                if config.vision_config.is_some() && prefix_caching.is_none() {
-                    tracing::info!("Disabling prefix caching because of VLM model");
-                    prefix_caching = Some("0".to_string());
-                }
                 match config.model_type.as_deref() {
                     Some("gemma2") | Some("falcon") | Some("deepseek_v2") => {
                         // Required because gemma2 needs bfloat16 which is not supported by
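
For context, here is a minimal sketch of what the reorder changes. It assumes a pared-down `Config` with only the two fields this hunk reads (`vision_config`, `head_dim`); the real launcher types are richer, and `tracing` is elided to keep the sketch dependency-free. The point: with the vision check nested inside the `head_dim` match arm, a VLM whose head dimension is not 64/128/256 never reached it; hoisted to the top of `if let Some(config)`, it applies to every model.

```rust
// Hypothetical, pared-down types for illustration only; the launcher's real
// `Config` has many more fields.
struct Config {
    vision_config: Option<()>, // Some(_) marks a vision-language model
    head_dim: Option<u32>,
}

// Mirrors the hunk's *new* control flow for the prefix-caching decision.
fn resolve_prefix_caching(
    config: &Option<Config>,
    lora_adapters: &Option<String>,
) -> Option<String> {
    let mut prefix_caching: Option<String> = std::env::var("USE_PREFIX_CACHING").ok();
    if let Some(config) = config {
        // Hoisted out of the head_dim match: VLMs now disable prefix caching
        // regardless of head_dim.
        if config.vision_config.is_some() && prefix_caching.is_none() {
            prefix_caching = Some("0".to_string());
        }
        // The LoRA check stays gated on the supported head dimensions.
        match config.head_dim {
            Some(h) if h == 64 || h == 128 || h == 256 => {
                if lora_adapters.is_some() && prefix_caching.is_none() {
                    prefix_caching = Some("0".to_string());
                }
            }
            _ => {}
        }
    }
    prefix_caching
}

fn main() {
    // A VLM with head_dim 96: under the old ordering the vision check was
    // unreachable (96 misses the 64/128/256 arm); now it still fires.
    // Assumes USE_PREFIX_CACHING is unset in the environment.
    let cfg = Some(Config { vision_config: Some(()), head_dim: Some(96) });
    assert_eq!(resolve_prefix_caching(&cfg, &None), Some("0".to_string()));
}
```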