Add default to Gemma Causality.

This commit is contained in:
Nicolas Patry 2024-07-04 16:36:16 +02:00
parent fc5bfa070a
commit 425f348e48
No known key found for this signature in database
GPG Key ID: E939E8CC91A1C674

View File

@ -162,7 +162,7 @@ def _load_gqa(config, prefix: str, weights):
class FlashGemmaAttention(torch.nn.Module):
def __init__(self, prefix: str, config, weights, causal: bool):
def __init__(self, prefix: str, config, weights, causal: bool = True):
super().__init__()
self.num_heads = config.num_attention_heads
self.head_size = config.head_dim