From 4aa0642f4d842710f874b569d20c4aa21f767dab Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 4 Jul 2024 17:17:46 +0200 Subject: [PATCH] Default value for gemma/gemma2. --- .../models/custom_modeling/flash_gemma2_modeling.py | 2 +- .../models/custom_modeling/flash_gemma_modeling.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py index cfa6b2fe..625baa91 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py @@ -442,7 +442,7 @@ class FlashGemma2Model(torch.nn.Module): class FlashGemma2ForCausalLM(torch.nn.Module): - def __init__(self, prefix, config, weights, causal: bool): + def __init__(self, prefix, config, weights, *, causal: bool = True): super().__init__() embed_norm = config.hidden_size**0.5 diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py index 4d731cbf..864bf9b0 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py @@ -419,7 +419,7 @@ class FlashGemmaModel(torch.nn.Module): class FlashGemmaForCausalLM(torch.nn.Module): - def __init__(self, prefix, config, weights, causal: bool): + def __init__(self, prefix, config, weights, *, causal: bool = True): super().__init__() embed_norm = config.hidden_size**0.5