diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
index a11ab1d3..eca01bbb 100644
--- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
@@ -120,7 +120,7 @@ class MistralAttention(torch.nn.Module):
         if hasattr(config, "head_dim"):
             self.head_size = config.head_dim
         else:
-            self.head_size = hidden_size // num_attention_heads
+            self.head_size = self.hidden_size // self.num_heads

         self.rotary_emb = PositionRotaryEmbedding.static(
             config=config,
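
For context, a minimal sketch of what this hunk changes: earlier in the real constructor, `self.hidden_size` and `self.num_heads` are set from the config, while bare `hidden_size` and `num_attention_heads` do not appear to be local names at this point, so the old fallback would presumably raise a NameError for any config without a `head_dim` attribute. The simplified constructor below is an illustration only (the real `__init__` takes more arguments such as prefix and weights); only the fallback branch mirrors the patch:

    import torch


    class MistralAttention(torch.nn.Module):
        # Simplified sketch of the constructor, not the full TGI class.
        def __init__(self, config):
            super().__init__()
            # Assumed to mirror attributes set earlier in the real __init__.
            self.hidden_size = config.hidden_size
            self.num_heads = config.num_attention_heads

            if hasattr(config, "head_dim"):
                # Configs that specify head_dim explicitly take precedence.
                self.head_size = config.head_dim
            else:
                # Patched fallback: derive the per-head size from the instance
                # attributes instead of the undefined bare names used before.
                self.head_size = self.hidden_size // self.num_heads

For a config without `head_dim` and with, say, hidden_size=4096 and num_attention_heads=32 (Mistral-7B-like values), the fixed branch yields head_size=128 instead of crashing.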