Don't enable prefix caching on VLM just yet.

Nicolas Patry 2024-08-27 09:58:19 +02:00
parent e30fb25444
commit f1c0735453


@@ -68,16 +68,16 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
     let mut prefix_caching: Option<String> = std::env::var("USE_PREFIX_CACHING").ok();
     let mut attention: Option<String> = std::env::var("ATTENTION").ok();
     if let Some(config) = config {
+        if config.vision_config.is_some() && prefix_caching.is_none() {
+            tracing::info!("Disabling prefix caching because of VLM model");
+            prefix_caching = Some("0".to_string());
+        }
         match config.head_dim {
             Some(h) if h == 64 || h == 128 || h == 256 => {
                 if lora_adapters.is_some() && prefix_caching.is_none() {
                     tracing::info!("Disabling prefix caching because of lora adapters");
                     prefix_caching = Some("0".to_string());
                 }
-                if config.vision_config.is_some() && prefix_caching.is_none() {
-                    tracing::info!("Disabling prefix caching because of VLM model");
-                    prefix_caching = Some("0".to_string());
-                }
                 match config.model_type.as_deref() {
                     Some("gemma2") | Some("falcon") | Some("deepseek_v2") => {
                         // Required because gemma2 needs bfloat16 which is not supported by
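
Read as a unified diff, the change moves the VLM check out of the head_dim match arm and in front of it, so prefix caching is now disabled for any model carrying a vision config, not only those whose head dim lands in the 64/128/256 arm; an explicit USE_PREFIX_CACHING env var still overrides everything. Below is a minimal runnable sketch of that precedence, assuming a pared-down Config, a hypothetical resolve_prefix_caching helper, println! standing in for tracing::info!, and a "1" fallback; none of these are the launcher's actual API.

struct Config {
    vision_config: Option<()>, // stand-in: presence alone marks a VLM
    head_dim: Option<usize>,
}

// Hypothetical helper mirroring only the prefix-caching half of
// resolve_attention; the real function also resolves `attention`.
fn resolve_prefix_caching(config: &Option<Config>, lora_adapters: &Option<String>) -> String {
    // An explicit env override always wins.
    let mut prefix_caching: Option<String> = std::env::var("USE_PREFIX_CACHING").ok();
    if let Some(config) = config {
        // Post-commit order: the VLM check runs before the head_dim match,
        // so every vision model is caught regardless of head size.
        if config.vision_config.is_some() && prefix_caching.is_none() {
            println!("Disabling prefix caching because of VLM model");
            prefix_caching = Some("0".to_string());
        }
        if let Some(64 | 128 | 256) = config.head_dim {
            if lora_adapters.is_some() && prefix_caching.is_none() {
                println!("Disabling prefix caching because of lora adapters");
                prefix_caching = Some("0".to_string());
            }
        }
    }
    // Assumed default; the real launcher derives it from more context.
    prefix_caching.unwrap_or_else(|| "1".to_string())
}

fn main() {
    // A VLM whose head_dim misses the 64/128/256 arm: with no
    // USE_PREFIX_CACHING set, it now resolves to "0" anyway.
    let vlm = Some(Config { vision_config: Some(()), head_dim: Some(32) });
    assert_eq!(resolve_prefix_caching(&vlm, &None), "0");
}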