From 425f348e485ef726192bd3318b4aa54e92cc1212 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 4 Jul 2024 16:36:16 +0200 Subject: [PATCH] Add default to Gemma Causality. --- .../models/custom_modeling/flash_gemma_modeling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py index 842df0d4..4d731cbf 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py @@ -162,7 +162,7 @@ def _load_gqa(config, prefix: str, weights): class FlashGemmaAttention(torch.nn.Module): - def __init__(self, prefix: str, config, weights, causal: bool): + def __init__(self, prefix: str, config, weights, causal: bool = True): super().__init__() self.num_heads = config.num_attention_heads self.head_size = config.head_dim