Fixing the default for VLM.

Nicolas Patry 2024-08-26 22:45:04 +02:00
parent 27b566baa8
commit e30fb25444
No known key found for this signature in database
GPG Key ID: 64AF4752B2967863
2 changed files with 20 additions and 14 deletions

View File

@@ -606,8 +606,8 @@ mod tests {
         cache.free(allocation.blocks.clone(), allocation.allocation_id);

         let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
-        assert_eq!(allocation.blocks, vec![8, 9, 6, 7]);
-        assert_eq!(allocation.slots, vec![16, 17, 18, 19, 12, 13, 14, 15]);
+        assert_eq!(allocation.blocks, vec![8, 9, 10, 11]);
+        assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22, 23]);
         assert_eq!(allocation.prefix_len, 4);
     }
@@ -618,15 +618,12 @@ mod tests {
         assert_eq!(allocation.blocks, vec![8, 9, 10, 11]);
         assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22]);
         assert_eq!(allocation.prefix_len, 0);

-        cache.free(
-            allocation.blocks[..allocation.blocks.len() - 1].to_vec(),
-            allocation.allocation_id,
-        );
-        let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
-        assert_eq!(allocation.blocks, vec![8, 9, 6, 7]);
-        assert_eq!(allocation.slots, vec![16, 17, 18, 19, 12, 13, 14, 15]);
-        assert_eq!(allocation.prefix_len, 4);
+        cache.free(allocation.blocks.clone(), allocation.allocation_id);
+        let allocation = cache.allocate(7, Some(Arc::new(vec![0, 1, 2]))).unwrap();
+        assert_eq!(allocation.blocks, vec![8, 9, 10, 11]);
+        assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22]);
+        assert_eq!(allocation.prefix_len, 2);
     }

     #[test]
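The expected `slots` in these tests follow mechanically from the expected `blocks`: the values are consistent with a test cache that uses two slots per block (block `i` covers slots `2*i` and `2*i + 1`). A minimal standalone sketch of that mapping, with a hypothetical helper name and written for illustration only, not taken from the repository:

```rust
// Illustrative helper (not repository code): expand block ids into slot ids,
// assuming `block_size` slots per block, truncated to the slots actually needed.
fn slots_for_blocks(blocks: &[u32], block_size: u32, needed: usize) -> Vec<u32> {
    blocks
        .iter()
        .copied()
        .flat_map(|b| (0..block_size).map(move |o| b * block_size + o))
        .take(needed)
        .collect()
}

fn main() {
    // Blocks [8, 9, 10, 11] with two slots per block cover slots 16..=23,
    // matching the first test's new expectations.
    assert_eq!(
        slots_for_blocks(&[8, 9, 10, 11], 2, 8),
        vec![16, 17, 18, 19, 20, 21, 22, 23]
    );
    // A 7-slot allocation over the same blocks stops at slot 22,
    // matching the second test's new expectations.
    assert_eq!(
        slots_for_blocks(&[8, 9, 10, 11], 2, 7),
        vec![16, 17, 18, 19, 20, 21, 22]
    );
}
```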

View File

@@ -74,16 +74,19 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
             tracing::info!("Disabling prefix caching because of lora adapters");
             prefix_caching = Some("0".to_string());
         }
+        if config.vision_config.is_some() && prefix_caching.is_none() {
+            tracing::info!("Disabling prefix caching because of VLM model");
+            prefix_caching = Some("0".to_string());
+        }
         match config.model_type.as_deref() {
             Some("gemma2") | Some("falcon") | Some("deepseek_v2") => {
                 // Required because gemma2 needs bfloat16 which is not supported by
                 // flashinfer ?
-                if prefix_caching.is_none() {
+                if attention.is_none() {
                     tracing::info!(
                         "Forcing flash decoding because model {} requires it",
                         config.model_type.as_ref().unwrap()
                     );
-                    prefix_caching = Some("1".to_string());
                     attention = Some("flashdecoding".to_string());
                 }
             }
@@ -91,9 +94,8 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
                 }
             }
             _ => {
-                if prefix_caching.is_none() {
+                if attention.is_none() {
                     tracing::info!("Forcing flash decoding because head dim is not supported by flashinfer, also disabling prefix caching");
-                    prefix_caching = Some("1".to_string());
                     attention = Some("flashdecoding".to_string());
                 }
             }
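Taken together, these two hunks change the fallback rule: configs with a `vision_config` now default to prefix caching off, and flash decoding is forced only when the user has not chosen an attention backend (previously the check keyed off `prefix_caching` and also forced it back to "1"). A condensed sketch of the resulting precedence, using made-up names (`resolve_defaults`, `has_lora_adapters`, `has_vision_config`) and simplified inputs rather than the launcher's actual types:

```rust
// Condensed, illustrative sketch of the defaulting order after this change.
// The real function also handles head_dim checks, logging, and other model types.
fn resolve_defaults(
    has_lora_adapters: bool,
    has_vision_config: bool,
    model_type: Option<&str>,
    mut prefix_caching: Option<String>, // user-provided override, if any
    mut attention: Option<String>,      // user-provided override, if any
) -> (Option<String>, Option<String>) {
    if has_lora_adapters {
        // Per the surrounding context: lora adapters disable prefix caching.
        prefix_caching = Some("0".to_string());
    }
    if has_vision_config && prefix_caching.is_none() {
        // New in this commit: VLMs default to prefix caching off.
        prefix_caching = Some("0".to_string());
    }
    let needs_flashdecoding = matches!(
        model_type,
        Some("gemma2") | Some("falcon") | Some("deepseek_v2")
    );
    if needs_flashdecoding && attention.is_none() {
        // Also new: only force the backend when the user did not pick one,
        // and no longer force prefix_caching back to "1".
        attention = Some("flashdecoding".to_string());
    }
    (prefix_caching, attention)
}
```

Returning both values as `Option`s keeps the final fall-through defaults, which are not visible in this diff, out of the sketch.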
@@ -115,6 +117,7 @@ struct RawConfig {
     hidden_size: Option<usize>,
     num_attention_heads: Option<usize>,
     head_dim: Option<usize>,
+    vision_config: Option<VisionConfig>,
 }

 #[derive(Deserialize)]
@@ -122,12 +125,16 @@ struct QuantizationConfig {
     quant_method: Option<Quantization>,
 }

+#[derive(Deserialize)]
+struct VisionConfig {}
+
 #[derive(Deserialize)]
 struct Config {
     max_position_embeddings: Option<usize>,
     quantize: Option<Quantization>,
     head_dim: Option<usize>,
     model_type: Option<String>,
+    vision_config: Option<VisionConfig>,
 }

 impl From<RawConfig> for Config {
@@ -154,11 +161,13 @@ impl From<RawConfig> for Config {
             }
         });
         let model_type = other.model_type;
+        let vision_config = other.vision_config;
         Config {
             max_position_embeddings,
             quantize,
             head_dim,
             model_type,
+            vision_config,
         }
     }
 }
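The config plumbing in the last three hunks relies on serde behavior: any `vision_config` object in `config.json`, even an empty one, deserializes into `Some(VisionConfig)` because the new struct declares no required fields. A trimmed-down sketch of that behavior, assuming the serde and serde_json crates the launcher already uses for `RawConfig` and keeping only the fields relevant here (the real struct carries many more; the model names in the JSON are just examples):

```rust
use serde::Deserialize;

// Empty struct: we only care whether the key exists, not what it contains.
#[derive(Deserialize)]
struct VisionConfig {}

// Trimmed-down stand-in for the launcher's RawConfig.
#[derive(Deserialize)]
struct RawConfig {
    model_type: Option<String>,
    vision_config: Option<VisionConfig>,
}

fn main() -> Result<(), serde_json::Error> {
    // A text-only model: no `vision_config` key, so the option stays None.
    let llm: RawConfig = serde_json::from_str(r#"{ "model_type": "llama" }"#)?;
    assert!(llm.vision_config.is_none());

    // A VLM config: the (possibly empty) object is enough to flag it,
    // which is what the new prefix-caching default keys off.
    let vlm: RawConfig =
        serde_json::from_str(r#"{ "model_type": "idefics2", "vision_config": {} }"#)?;
    assert!(vlm.vision_config.is_some());
    Ok(())
}
```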