From e30fb254449ec738d6a6fcda4fd3da57acfd2ea6 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Mon, 26 Aug 2024 22:45:04 +0200
Subject: [PATCH] Fixing the default for vlm.

---
 backends/v3/src/radix.rs | 17 +++++++----------
 launcher/src/main.rs     | 17 +++++++++++++----
 2 files changed, 20 insertions(+), 14 deletions(-)

diff --git a/backends/v3/src/radix.rs b/backends/v3/src/radix.rs
index 6f6d61a9a..c32ba64fb 100644
--- a/backends/v3/src/radix.rs
+++ b/backends/v3/src/radix.rs
@@ -606,8 +606,8 @@ mod tests {
         cache.free(allocation.blocks.clone(), allocation.allocation_id);
 
         let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
-        assert_eq!(allocation.blocks, vec![8, 9, 6, 7]);
-        assert_eq!(allocation.slots, vec![16, 17, 18, 19, 12, 13, 14, 15]);
+        assert_eq!(allocation.blocks, vec![8, 9, 10, 11]);
+        assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22, 23]);
         assert_eq!(allocation.prefix_len, 4);
     }
 
@@ -618,15 +618,12 @@ mod tests {
         assert_eq!(allocation.blocks, vec![8, 9, 10, 11]);
         assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22]);
         assert_eq!(allocation.prefix_len, 0);
-        cache.free(
-            allocation.blocks[..allocation.blocks.len() - 1].to_vec(),
-            allocation.allocation_id,
-        );
+        cache.free(allocation.blocks.clone(), allocation.allocation_id);
 
-        let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
-        assert_eq!(allocation.blocks, vec![8, 9, 6, 7]);
-        assert_eq!(allocation.slots, vec![16, 17, 18, 19, 12, 13, 14, 15]);
-        assert_eq!(allocation.prefix_len, 4);
+        let allocation = cache.allocate(7, Some(Arc::new(vec![0, 1, 2]))).unwrap();
+        assert_eq!(allocation.blocks, vec![8, 9, 10, 11]);
+        assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22]);
+        assert_eq!(allocation.prefix_len, 2);
     }
 
     #[test]
diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index 557b3f8c7..f26251123 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -74,16 +74,19 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
                     tracing::info!("Disabling prefix caching because of lora adapters");
                     prefix_caching = Some("0".to_string());
                 }
+                if config.vision_config.is_some() && prefix_caching.is_none() {
+                    tracing::info!("Disabling prefix caching because of VLM model");
+                    prefix_caching = Some("0".to_string());
+                }
                 match config.model_type.as_deref() {
                     Some("gemma2") | Some("falcon") | Some("deepseek_v2") => {
                         // Required because gemma2 needs bfloat16 which is not supported by
                         // flashinfer ?
-                        if prefix_caching.is_none() {
+                        if attention.is_none() {
                             tracing::info!(
                                 "Forcing flash decoding because model {} requires it",
                                 config.model_type.as_ref().unwrap()
                             );
-                            prefix_caching = Some("1".to_string());
                             attention = Some("flashdecoding".to_string());
                         }
                     }
@@ -91,9 +94,8 @@
                 }
             }
             _ => {
-                if prefix_caching.is_none() {
+                if attention.is_none() {
                     tracing::info!("Forcing flash decoding because head dim is not supported by flashinfer, also disabling prefix caching");
-                    prefix_caching = Some("1".to_string());
                     attention = Some("flashdecoding".to_string());
                 }
             }
@@ -115,6 +117,7 @@ struct RawConfig {
     hidden_size: Option<usize>,
     num_attention_heads: Option<usize>,
     head_dim: Option<usize>,
+    vision_config: Option<VisionConfig>,
 }
 
 #[derive(Deserialize)]
@@ -122,12 +125,16 @@ struct QuantizationConfig {
     quant_method: Option<Quantization>,
 }
 
+#[derive(Deserialize)]
+struct VisionConfig {}
+
 #[derive(Deserialize)]
 struct Config {
     max_position_embeddings: Option<usize>,
     quantize: Option<Quantization>,
     head_dim: Option<usize>,
     model_type: Option<String>,
+    vision_config: Option<VisionConfig>,
 }
 
 impl From<RawConfig> for Config {
@@ -154,11 +161,13 @@ impl From<RawConfig> for Config {
             }
         });
         let model_type = other.model_type;
+        let vision_config = other.vision_config;
         Config {
             max_position_embeddings,
             quantize,
             head_dim,
             model_type,
+            vision_config,
         }
     }
 }
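
Note (not part of the patch): the sketch below is a simplified, self-contained model of the default resolution this change targets, useful for reviewing the behaviour without reading the full launcher. The helper name `resolve_defaults`, its boolean inputs, and the fallback values "1"/"flashinfer" are assumptions for illustration only; the real `resolve_attention` reads the parsed model config and environment overrides.

// Illustrative sketch, assuming fallback defaults of "1" (prefix caching on)
// and "flashinfer" (attention backend) when nothing overrides them.
fn resolve_defaults(
    has_vision_config: bool,
    has_lora_adapters: bool,
    model_type: Option<&str>,
) -> (String, String) {
    let mut prefix_caching: Option<String> = None;
    let mut attention: Option<String> = None;

    // Existing behaviour: lora adapters default prefix caching off.
    if has_lora_adapters && prefix_caching.is_none() {
        prefix_caching = Some("0".to_string());
    }
    // New in this patch: a vision config (VLM) also defaults prefix caching off.
    if has_vision_config && prefix_caching.is_none() {
        prefix_caching = Some("0".to_string());
    }
    // Models that need flash decoding now only set `attention`; they no longer
    // force prefix caching on.
    match model_type {
        Some("gemma2") | Some("falcon") | Some("deepseek_v2") => {
            if attention.is_none() {
                attention = Some("flashdecoding".to_string());
            }
        }
        _ => {}
    }

    // Assumed fallbacks for this sketch.
    (
        prefix_caching.unwrap_or_else(|| "1".to_string()),
        attention.unwrap_or_else(|| "flashinfer".to_string()),
    )
}

fn main() {
    // A VLM no longer inherits prefix caching "1" by default.
    assert_eq!(
        resolve_defaults(true, false, None),
        ("0".to_string(), "flashinfer".to_string())
    );
    // gemma2 still gets flash decoding, while prefix caching keeps its own default.
    assert_eq!(
        resolve_defaults(false, false, Some("gemma2")),
        ("1".to_string(), "flashdecoding".to_string())
    );
}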