launcher: ensure correct detection of Gemma 3 head size

This commit is contained in:
Daniël de Kok 2025-04-11 11:56:18 +00:00
parent 9a8d0462e1
commit a9b26b221a

View File

@ -260,11 +260,22 @@ struct Config {
impl Config {
fn get_head_dim(&self) -> Option<usize> {
self.head_dim.or_else(|| {
self.text_config
.as_ref()
.and_then(|text_config| text_config.head_dim)
})
if let Some(head_dim) = self.head_dim {
return Some(head_dim);
}
let text_config = self.text_config.as_ref()?;
if let Some(head_size) = text_config.head_dim {
return Some(head_size);
}
match self.model_type.as_deref() {
// We special-case gemma3 here, since we need flashinfer for
// handling bidirectional masks. And flashinfer can only be
// used when the head size is known.
Some("gemma3") => Some(256),
_ => None,
}
}
fn flop(&self) -> Option<u64> {