diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index 4eeea02d..b4e1a6b7 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -47,6 +47,7 @@ struct Config {
     max_position_embeddings: Option<usize>,
     quantize: Option<Quantization>,
     head_dim: Option<usize>,
+    model_type: Option<String>,
 }
 
 impl From<RawConfig> for Config {
@@ -72,10 +73,12 @@ impl From<RawConfig> for Config {
                 _ => None,
             }
         });
+        let model_type = other.model_type;
         Config {
             max_position_embeddings,
             quantize,
             head_dim,
+            model_type,
         }
     }
 }
@@ -1492,10 +1495,6 @@ fn main() -> Result<(), LauncherError> {
         let content = std::fs::read_to_string(filename)?;
         let config: RawConfig = serde_json::from_str(&content)?;
-        if config.model_type == Some("gemma2".to_string()) {
-            tracing::info!("Forcing flash decoding because of softcap usage");
-            std::env::set_var("ATTENTION", "flashdecoding");
-        }
         let config: Config = config.into();
 
         match config.head_dim {
             Some(h) if h == 64 || h == 128 || h == 256 => {
@@ -1504,6 +1503,15 @@ fn main() -> Result<(), LauncherError> {
                     tracing::info!("Disabling prefix caching because of lora adapters");
                     std::env::set_var("USE_PREFIX_CACHING", "0");
                 }
+                match config.model_type.as_deref() {
+                    Some("gemma2") | Some("falcon") => {
+                        // Required because gemma2 needs bfloat16 which is not supported by
+                        // flashinfer ?
+                        std::env::set_var("USE_PREFIX_CACHING", "0");
+                        std::env::set_var("ATTENTION", "flashdecoding");
+                    }
+                    _ => {}
+                }
             }
             _ => {
                 tracing::info!("Forcing flash decoding because head dim is not supported by flashinfer, also disabling prefix caching")
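
Below is a minimal, self-contained sketch (not part of the patch) of the behaviour the new match arm introduces: the launcher now keys the attention-backend overrides on the config's model_type rather than special-casing gemma2 inline. The pared-down RawConfig and the helper name apply_model_type_overrides are illustrative assumptions, not the launcher's actual items; only the env vars and the gemma2/falcon arms are taken from the diff. Assumes serde and serde_json as dependencies.

use serde::Deserialize;

// Only the field this diff touches; the real RawConfig has many more.
#[derive(Deserialize)]
struct RawConfig {
    model_type: Option<String>,
}

// Hypothetical helper mirroring the new match arm in main():
// per the patch comments, gemma2 needs bfloat16, which flashinfer
// may not support, so it (and falcon) fall back to flashdecoding
// with prefix caching disabled.
fn apply_model_type_overrides(model_type: Option<&str>) {
    match model_type {
        Some("gemma2") | Some("falcon") => {
            std::env::set_var("USE_PREFIX_CACHING", "0");
            std::env::set_var("ATTENTION", "flashdecoding");
        }
        _ => {}
    }
}

fn main() -> Result<(), serde_json::Error> {
    // Stand-in for reading config.json from the model repo.
    let raw: RawConfig = serde_json::from_str(r#"{ "model_type": "gemma2" }"#)?;
    apply_model_type_overrides(raw.model_type.as_deref());
    assert_eq!(std::env::var("ATTENTION").as_deref(), Ok("flashdecoding"));
    assert_eq!(std::env::var("USE_PREFIX_CACHING").as_deref(), Ok("0"));
    Ok(())
}

The sketch only exercises the env-var side; in the launcher these variables are set in the parent process, so they act as defaults inherited by the processes it spawns when the attention backend is chosen.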