Make black formatting happy

This commit is contained in:
Alvaro Bartolome 2024-09-25 19:10:10 +02:00
parent 3b7e010a4c
commit fb28d374e1
No known key found for this signature in database

View File

@ -163,7 +163,9 @@ def _load_gqa(config, prefix: str, weights):
 class FlashGemma2Attention(torch.nn.Module):
-    def __init__(self, prefix: str, config, weights, layer_id, causal: bool, is_sliding: bool):
+    def __init__(
+        self, prefix: str, config, weights, layer_id, causal: bool, is_sliding: bool
+    ):
         super().__init__()
         self.num_heads = config.num_attention_heads
         self.head_size = config.head_dim
@ -345,7 +347,9 @@ class Gemma2MLP(nn.Module):
 class FlashGemma2Layer(nn.Module):
-    def __init__(self, prefix: str, config, weights, layer_id, causal: bool, is_sliding: bool):
+    def __init__(
+        self, prefix: str, config, weights, layer_id, causal: bool, is_sliding: bool
+    ):
         super().__init__()
         self.self_attn = FlashGemma2Attention(
             prefix=f"{prefix}.self_attn",