From 25c9611c04b70859dd51350d6721e5e07f34db4a Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 4 Jul 2024 17:18:26 +0200 Subject: [PATCH] Wrong default. --- .../models/custom_modeling/flash_gemma_modeling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py index 864bf9b0..b7ce6307 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py @@ -162,7 +162,7 @@ def _load_gqa(config, prefix: str, weights): class FlashGemmaAttention(torch.nn.Module): - def __init__(self, prefix: str, config, weights, causal: bool = True): + def __init__(self, prefix: str, config, weights, causal: bool): super().__init__() self.num_heads = config.num_attention_heads self.head_size = config.head_dim