From 12ab24ae647d402f87d54882b18df8f13829e5fa Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Mon, 29 May 2023 12:10:17 +0200
Subject: [PATCH] fix normal att

---
 .../models/custom_modeling/flash_rw_modeling.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
index 2cacc518..c13ffa7c 100644
--- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
@@ -125,12 +125,12 @@ class FlashRWAttention(torch.nn.Module):
 
         # Split query from key_value
         query, kv = qkv.split(
-            [self.head_size * self.num_heads, 2 * self.head_size], dim=1
+            [self.head_size * self.num_heads, 2 * self.head_size * self.num_heads_kv], dim=1
         )
 
         # Prepare query and key_value for indexing
         query = query.view(-1, self.num_heads, self.head_size)
-        kv = kv.view(-1, 2, 1, self.head_size)
+        kv = kv.view(-1, 2, self.num_heads_kv, self.head_size)
 
         # Inplace rotary
         self.rotary_emb(query, cos, sin)
@@ -141,7 +141,7 @@ class FlashRWAttention(torch.nn.Module):
             # Copy to layer past
             layer_past[...] = kv
             # Expand to query shape
-            kv = kv.expand(-1, 2, query.shape[1], self.head_size)
+            kv = kv.expand(-1, 2, self.num_heads, self.head_size)
 
             # output
             attn_output = torch.empty_like(query)
@@ -168,7 +168,7 @@ class FlashRWAttention(torch.nn.Module):
             # Add present to the layer_past tensor at the correct indices
             layer_past[layer_past_present_indices] = kv
             # Expand to query shape
-            kv = layer_past.expand(-1, 2, query.shape[1], self.head_size)
+            kv = layer_past.expand(-1, 2, self.num_heads, self.head_size)
 
             # output
             attn_output = torch.empty_like(query)
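
Note (not part of the patch): the old code split off exactly 2 * head_size columns of the fused qkv projection and viewed them as a single kv head, which only matches the projection layout when num_heads_kv == 1 (multi-query attention); sizing both the split and the view by num_heads_kv also covers the "normal" attention layout where num_heads_kv == num_heads. Below is a minimal, self-contained shape check that mirrors the patched split/view/expand sequence for both cases. All sizes (num_heads=8, head_size=64, total_tokens=4) are hypothetical, chosen only for illustration.

import torch

# Hypothetical dimensions for illustration only.
num_heads, head_size = 8, 64
total_tokens = 4

# Cover multi-query attention (1 kv head) and normal attention
# (num_heads kv heads) with the same patched logic.
for num_heads_kv in (1, num_heads):
    # Fused qkv projection output: one slice per query head plus
    # 2 * num_heads_kv slices for the key/value heads.
    qkv = torch.randn(total_tokens, (num_heads + 2 * num_heads_kv) * head_size)

    # Patched split: the kv slice is sized by num_heads_kv instead of
    # being hardcoded to a single head.
    query, kv = qkv.split(
        [head_size * num_heads, 2 * head_size * num_heads_kv], dim=1
    )

    query = query.view(-1, num_heads, head_size)
    kv = kv.view(-1, 2, num_heads_kv, head_size)

    # Broadcast the kv heads across all query heads. expand() returns a
    # view (no copy); it is a size-1 broadcast when num_heads_kv == 1
    # and a no-op when num_heads_kv == num_heads.
    kv = kv.expand(-1, 2, num_heads, head_size)

    assert query.shape == (total_tokens, num_heads, head_size)
    assert kv.shape == (total_tokens, 2, num_heads, head_size)

With the pre-patch hardcoded split, only the num_heads_kv == 1 iteration of this check would pass; the patched version keeps the split, the view, and the expand mutually consistent for either head count.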