From fb28d374e1d5ee2adc2da2c522be783a6975599c Mon Sep 17 00:00:00 2001
From: Alvaro Bartolome <36760800+alvarobartt@users.noreply.github.com>
Date: Wed, 25 Sep 2024 19:10:10 +0200
Subject: [PATCH] Make `black` formatting happy

---
 .../models/custom_modeling/flash_gemma2_modeling.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py
index 11791cbf..887e187e 100644
--- a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py
@@ -163,7 +163,9 @@ def _load_gqa(config, prefix: str, weights):
 
 
 class FlashGemma2Attention(torch.nn.Module):
-    def __init__(self, prefix: str, config, weights, layer_id, causal: bool, is_sliding: bool):
+    def __init__(
+        self, prefix: str, config, weights, layer_id, causal: bool, is_sliding: bool
+    ):
         super().__init__()
         self.num_heads = config.num_attention_heads
         self.head_size = config.head_dim
@@ -345,7 +347,9 @@ class Gemma2MLP(nn.Module):
 
 
 class FlashGemma2Layer(nn.Module):
-    def __init__(self, prefix: str, config, weights, layer_id, causal: bool, is_sliding: bool):
+    def __init__(
+        self, prefix: str, config, weights, layer_id, causal: bool, is_sliding: bool
+    ):
         super().__init__()
         self.self_attn = FlashGemma2Attention(
             prefix=f"{prefix}.self_attn",