Fixing the default for vlm.

Nicolas Patry 2024-08-26 22:45:04 +02:00
parent 27b566baa8
commit e30fb25444
2 changed files with 20 additions and 14 deletions

@@ -606,8 +606,8 @@ mod tests {
cache.free(allocation.blocks.clone(), allocation.allocation_id);
let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
assert_eq!(allocation.blocks, vec![8, 9, 6, 7]);
assert_eq!(allocation.slots, vec![16, 17, 18, 19, 12, 13, 14, 15]);
assert_eq!(allocation.blocks, vec![8, 9, 10, 11]);
assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22, 23]);
assert_eq!(allocation.prefix_len, 4);
}
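
The updated assertions imply a block size of 2 and a contiguous block-to-slot mapping: block 8 owns slots 16 and 17, block 9 owns 18 and 19, and so on. A minimal sketch of that relationship (a standalone helper, not the crate's allocator; the block size of 2 is inferred from the asserts):

```rust
// Standalone sketch, not the crate's allocator: with an assumed block
// size of 2, each block id maps to a contiguous pair of slots
// (slot = block * block_size + offset), which is the shape of the
// updated assertions above.
fn slots_for_blocks(blocks: &[u32], block_size: u32) -> Vec<u32> {
    blocks
        .iter()
        .flat_map(|b| (0..block_size).map(move |i| b * block_size + i))
        .collect()
}

fn main() {
    assert_eq!(
        slots_for_blocks(&[8, 9, 10, 11], 2),
        vec![16, 17, 18, 19, 20, 21, 22, 23]
    );
    println!("ok");
}
```
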
@@ -618,15 +618,12 @@ mod tests {
assert_eq!(allocation.blocks, vec![8, 9, 10, 11]);
assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22]);
assert_eq!(allocation.prefix_len, 0);
cache.free(
allocation.blocks[..allocation.blocks.len() - 1].to_vec(),
allocation.allocation_id,
);
cache.free(allocation.blocks.clone(), allocation.allocation_id);
let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
assert_eq!(allocation.blocks, vec![8, 9, 6, 7]);
assert_eq!(allocation.slots, vec![16, 17, 18, 19, 12, 13, 14, 15]);
assert_eq!(allocation.prefix_len, 4);
let allocation = cache.allocate(7, Some(Arc::new(vec![0, 1, 2]))).unwrap();
assert_eq!(allocation.blocks, vec![8, 9, 10, 11]);
assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22]);
assert_eq!(allocation.prefix_len, 2);
}
#[test]

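The second test now frees the whole allocation before re-allocating, and the new final call requests 7 slots with the 3-token prefix [0, 1, 2] while expecting prefix_len == 2. That is consistent with prefix matching operating on whole blocks only: with a block size of 2, just one full block (2 tokens) of a 3-token prefix can be reused, whereas the 4-token prefix above yields prefix_len 4. A hypothetical helper (not part of the crate) showing that rounding:

```rust
// Hypothetical helper, not part of the crate: if prefixes are matched
// per full block, the reusable prefix length is the token count
// rounded down to a multiple of the block size.
fn reusable_prefix_len(prefix_tokens: usize, block_size: usize) -> usize {
    (prefix_tokens / block_size) * block_size
}

fn main() {
    assert_eq!(reusable_prefix_len(4, 2), 4); // prefix [0, 1, 2, 3] -> prefix_len 4
    assert_eq!(reusable_prefix_len(3, 2), 2); // prefix [0, 1, 2]    -> prefix_len 2
    println!("ok");
}
```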

@@ -74,16 +74,19 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
tracing::info!("Disabling prefix caching because of lora adapters");
prefix_caching = Some("0".to_string());
}
if config.vision_config.is_some() && prefix_caching.is_none() {
tracing::info!("Disabling prefix caching because of VLM model");
prefix_caching = Some("0".to_string());
}
match config.model_type.as_deref() {
Some("gemma2") | Some("falcon") | Some("deepseek_v2") => {
// Required because gemma2 needs bfloat16 which is not supported by
// flashinfer ?
if prefix_caching.is_none() {
if attention.is_none() {
tracing::info!(
"Forcing flash decoding because model {} requires it",
config.model_type.as_ref().unwrap()
);
prefix_caching = Some("1".to_string());
attention = Some("flashdecoding".to_string());
}
}
@@ -91,9 +94,8 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
}
}
_ => {
if prefix_caching.is_none() {
if attention.is_none() {
tracing::info!("Forcing flash decoding because head dim is not supported by flashinfer, also disabling prefix caching");
prefix_caching = Some("1".to_string());
attention = Some("flashdecoding".to_string());
}
}
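
Taken together, these launcher changes shift the defaults: a model whose config carries a vision_config now gets prefix caching disabled unless the user chose otherwise, while the gemma2/falcon/deepseek_v2 branch (and the unsupported-head-dim fallback) only forces the flashdecoding attention backend when no implementation was selected, and no longer touches prefix caching. A simplified standalone sketch of that resolution order, using assumed types and names rather than the launcher's exact code:

```rust
// Simplified sketch with assumed types, not the launcher's exact code:
// a config that carries a vision_config disables prefix caching by
// default, and a few model types force the flashdecoding backend only
// when no attention implementation was chosen explicitly.
struct Config {
    model_type: Option<String>,
    has_vision_config: bool,
}

fn resolve_defaults(
    config: &Config,
    mut prefix_caching: Option<String>,
    mut attention: Option<String>,
) -> (Option<String>, Option<String>) {
    if config.has_vision_config && prefix_caching.is_none() {
        // "Disabling prefix caching because of VLM model"
        prefix_caching = Some("0".to_string());
    }
    if attention.is_none() {
        if let Some("gemma2") | Some("falcon") | Some("deepseek_v2") =
            config.model_type.as_deref()
        {
            // These models need flash decoding (e.g. gemma2's bfloat16).
            attention = Some("flashdecoding".to_string());
        }
    }
    (prefix_caching, attention)
}

fn main() {
    let vlm = Config {
        model_type: Some("idefics".to_string()),
        has_vision_config: true,
    };
    assert_eq!(
        resolve_defaults(&vlm, None, None),
        (Some("0".to_string()), None)
    );

    let gemma2 = Config {
        model_type: Some("gemma2".to_string()),
        has_vision_config: false,
    };
    assert_eq!(
        resolve_defaults(&gemma2, None, None),
        (None, Some("flashdecoding".to_string()))
    );
    println!("ok");
}
```
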
@@ -115,6 +117,7 @@ struct RawConfig {
hidden_size: Option<usize>,
num_attention_heads: Option<usize>,
head_dim: Option<usize>,
vision_config: Option<VisionConfig>,
}
#[derive(Deserialize)]
@@ -122,12 +125,16 @@ struct QuantizationConfig {
quant_method: Option<Quantization>,
}
#[derive(Deserialize)]
struct VisionConfig {}
#[derive(Deserialize)]
struct Config {
max_position_embeddings: Option<usize>,
quantize: Option<Quantization>,
head_dim: Option<usize>,
model_type: Option<String>,
vision_config: Option<VisionConfig>,
}
impl From<RawConfig> for Config {
@@ -154,11 +161,13 @@ impl From<RawConfig> for Config {
}
});
let model_type = other.model_type;
let vision_config = other.vision_config;
Config {
max_position_embeddings,
quantize,
head_dim,
model_type,
vision_config,
}
}
}
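
The plumbing above only records whether config.json contains a vision_config object at all; its fields are irrelevant, which is why VisionConfig is an empty struct. A minimal serde sketch with abbreviated structs and a made-up config.json snippet (not the launcher's full RawConfig) showing how presence of the key is detected:

```rust
// Abbreviated sketch, not the launcher's full RawConfig: the empty
// VisionConfig struct only detects whether config.json contains a
// `vision_config` object (unknown fields inside it are ignored by serde).
use serde::Deserialize;

#[derive(Deserialize)]
struct VisionConfig {}

#[derive(Deserialize)]
struct Config {
    model_type: Option<String>,
    vision_config: Option<VisionConfig>,
}

fn main() {
    let vlm: Config = serde_json::from_str(
        r#"{ "model_type": "idefics", "vision_config": { "hidden_size": 1152 } }"#,
    )
    .unwrap();
    let text_only: Config = serde_json::from_str(r#"{ "model_type": "llama" }"#).unwrap();

    // This is the condition resolve_attention checks to treat a model as a VLM.
    assert!(vlm.vision_config.is_some());
    assert!(text_only.vision_config.is_none());
    println!("ok");
}
```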