Wrong default.

This commit is contained in:
Nicolas Patry 2024-07-04 17:18:26 +02:00
parent 4aa0642f4d
commit 25c9611c04
No known key found for this signature in database
GPG Key ID: E939E8CC91A1C674

View File

@ -162,7 +162,7 @@ def _load_gqa(config, prefix: str, weights):
class FlashGemmaAttention(torch.nn.Module): class FlashGemmaAttention(torch.nn.Module):
def __init__(self, prefix: str, config, weights, causal: bool = True): def __init__(self, prefix: str, config, weights, causal: bool):
super().__init__() super().__init__()
self.num_heads = config.num_attention_heads self.num_heads = config.num_attention_heads
self.head_size = config.head_dim self.head_size = config.head_dim