diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py
index cfa6b2fe..625baa91 100644
--- a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py
@@ -442,7 +442,7 @@ class FlashGemma2Model(torch.nn.Module):
 
 
 class FlashGemma2ForCausalLM(torch.nn.Module):
-    def __init__(self, prefix, config, weights, causal: bool):
+    def __init__(self, prefix, config, weights, *, causal: bool = True):
         super().__init__()
 
         embed_norm = config.hidden_size**0.5
diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py
index 4d731cbf..864bf9b0 100644
--- a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py
@@ -419,7 +419,7 @@ class FlashGemmaModel(torch.nn.Module):
 
 
 class FlashGemmaForCausalLM(torch.nn.Module):
-    def __init__(self, prefix, config, weights, causal: bool):
+    def __init__(self, prefix, config, weights, *, causal: bool = True):
         super().__init__()
 
         embed_norm = config.hidden_size**0.5
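
A minimal call-site sketch of what this signature change means, assuming `prefix`, `config`, and `weights` are constructed elsewhere (they are hypothetical placeholders here, not part of the patch): `causal` becomes keyword-only with a default of `True`, so the common causal case needs no extra argument and the flag can no longer be passed positionally.

# Hypothetical call sites, not taken from the diff; prefix/config/weights
# are assumed to exist.
model = FlashGemma2ForCausalLM(prefix, config, weights)                     # causal defaults to True
non_causal = FlashGemma2ForCausalLM(prefix, config, weights, causal=False)  # must be named
# FlashGemma2ForCausalLM(prefix, config, weights, False) would now raise a
# TypeError, since the bare `*` blocks positional use of `causal`.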