diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 5a1ca8ce..583220a6 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -103,6 +103,10 @@ fn resolve_attention(config: &Option, lora_adapters: &Option) -> ); attention = Some(fallback_attention.to_string()); } + if fallback_attention == "paged" && prefix_caching.is_none() { + tracing::info!("Disabling prefix caching because it is not supported with 'paged' attention"); + prefix_caching = Some("0".to_string()); + } } Some("t5") => {} _ => {} @@ -119,16 +123,9 @@ fn resolve_attention(config: &Option, lora_adapters: &Option) -> } } } + let attention = attention.unwrap_or("flashinfer".to_string()); - let prefix_caching = if attention == "paged" - && prefix_caching.is_none() - && compute_capability.is_some() - { - tracing::info!("Disabling prefix caching because it is not supported with 'flashinfer'"); - "false".to_string() - } else { - prefix_caching.unwrap_or("true".to_string()) - }; + let prefix_caching = prefix_caching.unwrap_or("true".to_string()); (prefix_caching, attention) }