diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
index cc2df11e..545da26a 100644
--- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
@@ -295,7 +295,11 @@ class FlashRWLargeAttention(torch.nn.Module):
             # Copy to layer past
             layer_past[...] = kv
             # Expand to query shape
-            kv = kv.unsqueeze(2).expand(-1, self.num_groups, self.num_heads, 2, self.head_size).reshape(-1, self.num_groups * self.num_heads, 2, self.head_size)
+            kv = (
+                kv.unsqueeze(2)
+                .expand(-1, self.num_groups, self.num_heads, 2, self.head_size)
+                .reshape(-1, self.num_groups * self.num_heads, 2, self.head_size)
+            )
 
             # output
             attn_output = torch.empty_like(query)
@@ -322,7 +326,11 @@ class FlashRWLargeAttention(torch.nn.Module):
             # Add present to the layer_past tensor at the correct indices
             layer_past[layer_past_present_indices] = kv
             # Expand to query shape
-            kv = layer_past.unsqueeze(2).expand(-1, self.num_groups, self.num_heads, 2, self.head_size).reshape(-1, self.num_groups * self.num_heads, 2, self.head_size)
+            kv = (
+                layer_past.unsqueeze(2)
+                .expand(-1, self.num_groups, self.num_heads, 2, self.head_size)
+                .reshape(-1, self.num_groups * self.num_heads, 2, self.head_size)
+            )
 
             # output
             attn_output = torch.empty_like(query)
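
For context on what the reformatted chain computes (the patch only re-wraps it, behavior is unchanged): the grouped key/value tensor holds one (key, value) pair per group, and the unsqueeze/expand/reshape broadcasts that pair to every query head in the group so the layout matches the query tensor. A minimal standalone sketch of that shape transformation, with made-up dimensions not taken from the PR:

# Illustrative sketch only; dimension values are invented for demonstration.
import torch

num_tokens = 4   # flattened batch * sequence length
num_groups = 2   # key/value groups
num_heads = 8    # query heads per group
head_size = 64

# One (key, value) pair per group: [tokens, groups, 2, head_size]
kv = torch.randn(num_tokens, num_groups, 2, head_size)

# Insert a head axis, expand it across the query heads of each group
# (expand is a view, no copy), then flatten groups * heads so the result
# lines up with a query tensor of shape [tokens, groups * heads, head_size].
expanded = (
    kv.unsqueeze(2)
    .expand(-1, num_groups, num_heads, 2, head_size)
    .reshape(-1, num_groups * num_heads, 2, head_size)
)

print(expanded.shape)  # torch.Size([4, 16, 2, 64])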