Default value for gemma/gemma2.

This commit is contained in:
Nicolas Patry 2024-07-04 17:17:46 +02:00
parent 425f348e48
commit 4aa0642f4d
No known key found for this signature in database
GPG Key ID: E939E8CC91A1C674
2 changed files with 2 additions and 2 deletions

View File

@ -442,7 +442,7 @@ class FlashGemma2Model(torch.nn.Module):
class FlashGemma2ForCausalLM(torch.nn.Module):
def __init__(self, prefix, config, weights, causal: bool):
def __init__(self, prefix, config, weights, *, causal: bool = True):
super().__init__()
embed_norm = config.hidden_size**0.5

View File

@ -419,7 +419,7 @@ class FlashGemmaModel(torch.nn.Module):
class FlashGemmaForCausalLM(torch.nn.Module):
def __init__(self, prefix, config, weights, causal: bool):
def __init__(self, prefix, config, weights, *, causal: bool = True):
super().__init__()
embed_norm = config.hidden_size**0.5