diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py index 864bf9b0..b7ce6307 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py @@ -162,7 +162,7 @@ def _load_gqa(config, prefix: str, weights): class FlashGemmaAttention(torch.nn.Module): - def __init__(self, prefix: str, config, weights, causal: bool = True): + def __init__(self, prefix: str, config, weights, causal: bool): super().__init__() self.num_heads = config.num_attention_heads self.head_size = config.head_dim