mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-19 13:52:07 +00:00
launcher: ensure correct detection of Gemma 3 head size
This commit is contained in:
parent
9a8d0462e1
commit
a9b26b221a
@ -260,11 +260,22 @@ struct Config {
|
||||
|
||||
impl Config {
|
||||
fn get_head_dim(&self) -> Option<usize> {
|
||||
self.head_dim.or_else(|| {
|
||||
self.text_config
|
||||
.as_ref()
|
||||
.and_then(|text_config| text_config.head_dim)
|
||||
})
|
||||
if let Some(head_dim) = self.head_dim {
|
||||
return Some(head_dim);
|
||||
}
|
||||
|
||||
let text_config = self.text_config.as_ref()?;
|
||||
if let Some(head_size) = text_config.head_dim {
|
||||
return Some(head_size);
|
||||
}
|
||||
|
||||
match self.model_type.as_deref() {
|
||||
// We special-case gemma3 here, since we need flashinfer for
|
||||
// handling bidirectional masks. And flashinfer can only be
|
||||
// used when the head size is known.
|
||||
Some("gemma3") => Some(256),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn flop(&self) -> Option<u64> {
|
||||
|
Loading…
Reference in New Issue
Block a user