From a9b26b221a7e35914da91143ed6eedbe96dac41e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?=
Date: Fri, 11 Apr 2025 11:56:18 +0000
Subject: [PATCH] launcher: ensure correct detection of Gemma 3 head size

---
 launcher/src/main.rs | 21 ++++++++++++++++-----
 1 file changed, 16 insertions(+), 5 deletions(-)

diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index acff85730..e481d7eb9 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -260,11 +260,22 @@ struct Config {
 
 impl Config {
     fn get_head_dim(&self) -> Option {
-        self.head_dim.or_else(|| {
-            self.text_config
-                .as_ref()
-                .and_then(|text_config| text_config.head_dim)
-        })
+        if let Some(head_dim) = self.head_dim {
+            return Some(head_dim);
+        }
+
+        let text_config = self.text_config.as_ref()?;
+        if let Some(head_size) = text_config.head_dim {
+            return Some(head_size);
+        }
+
+        match self.model_type.as_deref() {
+            // We special-case gemma3 here, since we need flashinfer for
+            // handling bidirectional masks. And flashinfer can only be
+            // used when the head size is known.
+            Some("gemma3") => Some(256),
+            _ => None,
+        }
     }
 
     fn flop(&self) -> Option {