Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-11 20:34:54 +00:00
Make black formatting happy
This commit is contained in:
parent 3b7e010a4c
commit fb28d374e1
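The two hunks below only rewrap a pair of `__init__` signatures: with black's default 88-character line limit, the one-line definitions are too long, so black moves the parameter list onto its own indented line. As a rough illustration (not part of the commit), the same rewrap can be reproduced with black's Python API; the snippet below is a minimal sketch assuming an installed `black` package exposing `black.format_str` and `black.Mode`, and the constant name is illustrative only.

# Sketch: reproduce black's rewrapping of an over-long signature.
# Assumes `pip install black`; black.format_str / black.Mode are black's
# public Python API, equivalent to running the `black` CLI on a file.
import black

LONG_SIGNATURE = (
    "def __init__(self, prefix: str, config, weights, layer_id, "
    "causal: bool, is_sliding: bool):\n"
    "    pass\n"
)

# The one-line form is 92 characters, over the default 88-character limit,
# so black splits it; the wrapped parameter line then fits on one line.
formatted = black.format_str(LONG_SIGNATURE, mode=black.Mode())
print(formatted)
# Expected output (matching the diff below, minus the class-level indentation):
# def __init__(
#     self, prefix: str, config, weights, layer_id, causal: bool, is_sliding: bool
# ):
#     pass

In the repository itself this formatting would normally be applied with the `black` CLI rather than the API; `black --check` reports files that would be reformatted without modifying them.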
@@ -163,7 +163,9 @@ def _load_gqa(config, prefix: str, weights):
 
 
 class FlashGemma2Attention(torch.nn.Module):
-    def __init__(self, prefix: str, config, weights, layer_id, causal: bool, is_sliding: bool):
+    def __init__(
+        self, prefix: str, config, weights, layer_id, causal: bool, is_sliding: bool
+    ):
         super().__init__()
         self.num_heads = config.num_attention_heads
         self.head_size = config.head_dim
@@ -345,7 +347,9 @@ class Gemma2MLP(nn.Module):
 
 
 class FlashGemma2Layer(nn.Module):
-    def __init__(self, prefix: str, config, weights, layer_id, causal: bool, is_sliding: bool):
+    def __init__(
+        self, prefix: str, config, weights, layer_id, causal: bool, is_sliding: bool
+    ):
         super().__init__()
         self.self_attn = FlashGemma2Attention(
             prefix=f"{prefix}.self_attn",