mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
Default value for gemma/gemma2.
This commit is contained in:
parent
425f348e48
commit
4aa0642f4d
@ -442,7 +442,7 @@ class FlashGemma2Model(torch.nn.Module):
|
||||
|
||||
|
||||
class FlashGemma2ForCausalLM(torch.nn.Module):
|
||||
def __init__(self, prefix, config, weights, causal: bool):
|
||||
def __init__(self, prefix, config, weights, *, causal: bool = True):
|
||||
super().__init__()
|
||||
|
||||
embed_norm = config.hidden_size**0.5
|
||||
|
@ -419,7 +419,7 @@ class FlashGemmaModel(torch.nn.Module):
|
||||
|
||||
|
||||
class FlashGemmaForCausalLM(torch.nn.Module):
|
||||
def __init__(self, prefix, config, weights, causal: bool):
|
||||
def __init__(self, prefix, config, weights, *, causal: bool = True):
|
||||
super().__init__()
|
||||
|
||||
embed_norm = config.hidden_size**0.5
|
||||
|
Loading…
Reference in New Issue
Block a user