From c7b899a438c700e9557410093f1dd37bdcbd59cf Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Tue, 30 May 2023 17:09:51 +0200 Subject: [PATCH] black --- .../models/custom_modeling/flash_rw_modeling.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index cc2df11e..545da26a 100644 --- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -295,7 +295,11 @@ class FlashRWLargeAttention(torch.nn.Module): # Copy to layer past layer_past[...] = kv # Expand to query shape - kv = kv.unsqueeze(2).expand(-1, self.num_groups, self.num_heads, 2, self.head_size).reshape(-1, self.num_groups * self.num_heads, 2, self.head_size) + kv = ( + kv.unsqueeze(2) + .expand(-1, self.num_groups, self.num_heads, 2, self.head_size) + .reshape(-1, self.num_groups * self.num_heads, 2, self.head_size) + ) # output attn_output = torch.empty_like(query) @@ -322,7 +326,11 @@ class FlashRWLargeAttention(torch.nn.Module): # Add present to the layer_past tensor at the correct indices layer_past[layer_past_present_indices] = kv # Expand to query shape - kv = layer_past.unsqueeze(2).expand(-1, self.num_groups, self.num_heads, 2, self.head_size).reshape(-1, self.num_groups * self.num_heads, 2, self.head_size) + kv = ( + layer_past.unsqueeze(2) + .expand(-1, self.num_groups, self.num_heads, 2, self.head_size) + .reshape(-1, self.num_groups * self.num_heads, 2, self.head_size) + ) # output attn_output = torch.empty_like(query)