Mirror of https://github.com/huggingface/text-generation-inference.git
Fixing the default for vlm.
This commit is contained in:
parent 27b566baa8
commit e30fb25444
@@ -606,8 +606,8 @@ mod tests {
        cache.free(allocation.blocks.clone(), allocation.allocation_id);

        let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
        assert_eq!(allocation.blocks, vec![8, 9, 6, 7]);
        assert_eq!(allocation.slots, vec![16, 17, 18, 19, 12, 13, 14, 15]);
        assert_eq!(allocation.blocks, vec![8, 9, 10, 11]);
        assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22, 23]);
        assert_eq!(allocation.prefix_len, 4);
    }
@@ -618,15 +618,12 @@ mod tests {
        assert_eq!(allocation.blocks, vec![8, 9, 10, 11]);
        assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22]);
        assert_eq!(allocation.prefix_len, 0);
        cache.free(
            allocation.blocks[..allocation.blocks.len() - 1].to_vec(),
            allocation.allocation_id,
        );
        cache.free(allocation.blocks.clone(), allocation.allocation_id);

        let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();
        assert_eq!(allocation.blocks, vec![8, 9, 6, 7]);
        assert_eq!(allocation.slots, vec![16, 17, 18, 19, 12, 13, 14, 15]);
        assert_eq!(allocation.prefix_len, 4);
        let allocation = cache.allocate(7, Some(Arc::new(vec![0, 1, 2]))).unwrap();
        assert_eq!(allocation.blocks, vec![8, 9, 10, 11]);
        assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22]);
        assert_eq!(allocation.prefix_len, 2);
    }

    #[test]
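Note: the two hunks above adjust the expectations of the prefix-cache allocator tests. After a sequence is freed, a new request with the same token prefix reuses the cached blocks, and `allocation.prefix_len` reports how many slots are already filled. A minimal sketch of that accounting follows; it is not the repository's RadixAllocator, and the helper name `cached_prefix_len` is made up for illustration.

```rust
// Sketch of the invariant the tests check: only whole blocks can be reused
// from the cache, so the reusable prefix length is the matched prefix
// rounded down to a multiple of the block size.
fn cached_prefix_len(matched_prefix_tokens: usize, block_size: usize) -> usize {
    (matched_prefix_tokens / block_size) * block_size
}

fn main() {
    // 4 matched tokens, block size 2 -> the 4 leading slots are cache hits,
    // matching `assert_eq!(allocation.prefix_len, 4)` above.
    assert_eq!(cached_prefix_len(4, 2), 4);
    // 3 matched tokens, block size 2 -> only 1 full block (2 slots) is
    // reusable, matching `assert_eq!(allocation.prefix_len, 2)`.
    assert_eq!(cached_prefix_len(3, 2), 2);
}
```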
@@ -74,16 +74,19 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
            tracing::info!("Disabling prefix caching because of lora adapters");
            prefix_caching = Some("0".to_string());
        }
        if config.vision_config.is_some() && prefix_caching.is_none() {
            tracing::info!("Disabling prefix caching because of VLM model");
            prefix_caching = Some("0".to_string());
        }
        match config.model_type.as_deref() {
            Some("gemma2") | Some("falcon") | Some("deepseek_v2") => {
                // Required because gemma2 needs bfloat16 which is not supported by
                // flashinfer ?
                if prefix_caching.is_none() {
                    if attention.is_none() {
                        tracing::info!(
                            "Forcing flash decoding because model {} requires it",
                            config.model_type.as_ref().unwrap()
                        );
                        prefix_caching = Some("1".to_string());
                        attention = Some("flashdecoding".to_string());
                    }
                }
@@ -91,9 +94,8 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
                }
            }
            _ => {
                if prefix_caching.is_none() {
                    if attention.is_none() {
                        tracing::info!("Forcing flash decoding because head dim is not supported by flashinfer, also disabling prefix caching");
                        prefix_caching = Some("1".to_string());
                        attention = Some("flashdecoding".to_string());
                    }
                }
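Note: the change this commit names, fixing the default for VLMs, is the new `vision_config` check in the hunks above: when the model config carries a vision section and the user has not chosen a prefix-caching setting, the launcher defaults it to "0". Below is a standalone sketch of that defaulting rule, not the launcher's actual `resolve_attention`; the fall-through default of "1" and the structs are simplified so the example compiles on its own.

```rust
// Pared-down stand-ins for the launcher's config types.
struct VisionConfig {}

struct Config {
    vision_config: Option<VisionConfig>,
}

fn default_prefix_caching(config: &Config, user_choice: Option<&str>) -> String {
    if let Some(choice) = user_choice {
        // An explicit user setting always wins.
        return choice.to_string();
    }
    if config.vision_config.is_some() {
        // VLM checkpoints declare a vision section: turn prefix caching off.
        return "0".to_string();
    }
    // Assumed fall-through default, only for this illustration.
    "1".to_string()
}

fn main() {
    let vlm = Config { vision_config: Some(VisionConfig {}) };
    let text_only = Config { vision_config: None };
    assert_eq!(default_prefix_caching(&vlm, None), "0");
    assert_eq!(default_prefix_caching(&vlm, Some("1")), "1");
    assert_eq!(default_prefix_caching(&text_only, None), "1");
}
```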
@@ -115,6 +117,7 @@ struct RawConfig {
    hidden_size: Option<usize>,
    num_attention_heads: Option<usize>,
    head_dim: Option<usize>,
    vision_config: Option<VisionConfig>,
}

#[derive(Deserialize)]
@@ -122,12 +125,16 @@ struct QuantizationConfig {
    quant_method: Option<Quantization>,
}

#[derive(Deserialize)]
struct VisionConfig {}

#[derive(Deserialize)]
struct Config {
    max_position_embeddings: Option<usize>,
    quantize: Option<Quantization>,
    head_dim: Option<usize>,
    model_type: Option<String>,
    vision_config: Option<VisionConfig>,
}

impl From<RawConfig> for Config {
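Note: these structs let serde skim a checkpoint's config.json for only the fields the launcher cares about; because unknown fields are ignored by default, the empty `VisionConfig {}` is enough to detect that a vision section exists. A sketch of such a parse follows, assuming the serde derive feature and serde_json; the JSON excerpt and its values are made up for illustration.

```rust
use serde::Deserialize;

// Trimmed-down mirrors of the structs in the diff.
#[derive(Deserialize)]
struct VisionConfig {}

#[derive(Deserialize)]
struct RawConfig {
    model_type: Option<String>,
    vision_config: Option<VisionConfig>,
}

fn main() -> Result<(), serde_json::Error> {
    // Hypothetical excerpt of a VLM config.json; real files carry many more keys.
    let raw = r#"{
        "model_type": "idefics2",
        "hidden_size": 4096,
        "vision_config": { "image_size": 980 }
    }"#;
    let config: RawConfig = serde_json::from_str(raw)?;
    // The mere presence of `vision_config` is what flips the prefix-caching default.
    assert!(config.vision_config.is_some());
    assert_eq!(config.model_type.as_deref(), Some("idefics2"));
    Ok(())
}
```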
@@ -154,11 +161,13 @@ impl From<RawConfig> for Config {
            }
        });
        let model_type = other.model_type;
        let vision_config = other.vision_config;
        Config {
            max_position_embeddings,
            quantize,
            head_dim,
            model_type,
            vision_config,
        }
    }
}
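Note: the closing braces at the top of this last hunk appear to belong to the launcher's fallback that derives `head_dim` from `hidden_size` and `num_attention_heads` when config.json does not state it directly (the `RawConfig` fields above suggest as much). A free-standing sketch of that arithmetic, under that assumption and with a made-up helper name:

```rust
// Use the explicit head_dim when present, otherwise derive it as
// hidden_size / num_attention_heads when that divides evenly.
fn resolve_head_dim(
    head_dim: Option<usize>,
    hidden_size: Option<usize>,
    num_attention_heads: Option<usize>,
) -> Option<usize> {
    head_dim.or(match (hidden_size, num_attention_heads) {
        (Some(h), Some(n)) if n != 0 && h % n == 0 => Some(h / n),
        _ => None,
    })
}

fn main() {
    // 4096 hidden size with 32 heads -> 128-dimensional heads.
    assert_eq!(resolve_head_dim(None, Some(4096), Some(32)), Some(128));
    // An explicit head_dim wins over the derived value.
    assert_eq!(resolve_head_dim(Some(96), Some(4096), Some(32)), Some(96));
}
```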