mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-19 22:02:06 +00:00
launcher: ensure correct detection of Gemma 3 head size
This commit is contained in:
parent
9a8d0462e1
commit
a9b26b221a
@ -260,11 +260,22 @@ struct Config {
|
|||||||
|
|
||||||
impl Config {
|
impl Config {
|
||||||
fn get_head_dim(&self) -> Option<usize> {
|
fn get_head_dim(&self) -> Option<usize> {
|
||||||
self.head_dim.or_else(|| {
|
if let Some(head_dim) = self.head_dim {
|
||||||
self.text_config
|
return Some(head_dim);
|
||||||
.as_ref()
|
}
|
||||||
.and_then(|text_config| text_config.head_dim)
|
|
||||||
})
|
let text_config = self.text_config.as_ref()?;
|
||||||
|
if let Some(head_size) = text_config.head_dim {
|
||||||
|
return Some(head_size);
|
||||||
|
}
|
||||||
|
|
||||||
|
match self.model_type.as_deref() {
|
||||||
|
// We special-case gemma3 here, since we need flashinfer for
|
||||||
|
// handling bidirectional masks. And flashinfer can only be
|
||||||
|
// used when the head size is known.
|
||||||
|
Some("gemma3") => Some(256),
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn flop(&self) -> Option<u64> {
|
fn flop(&self) -> Option<u64> {
|
||||||
|
Loading…
Reference in New Issue
Block a user