diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index 93de9648..cc2df11e 100644 --- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -1,3 +1,5 @@ +import os + import torch import torch.distributed @@ -292,22 +294,16 @@ class FlashRWLargeAttention(torch.nn.Module): if layer_past_present_indices is None: # Copy to layer past layer_past[...] = kv - k, v = kv.split(1, dim=2) # Expand to query shape - k = k.expand(-1, self.num_groups, self.num_heads, self.head_size).reshape( - -1, self.num_groups * self.num_heads, self.head_size - ) - v = v.expand(-1, self.num_groups, self.num_heads, self.head_size).reshape( - -1, self.num_groups * self.num_heads, self.head_size - ) + kv = kv.unsqueeze(2).expand(-1, self.num_groups, self.num_heads, 2, self.head_size).reshape(-1, self.num_groups * self.num_heads, 2, self.head_size) # output attn_output = torch.empty_like(query) # flash attention flash_attn_cuda.fwd( query, - k, - v, + kv[:, :, 0], + kv[:, :, 1], attn_output, cu_seqlens, cu_seqlens, @@ -325,22 +321,16 @@ class FlashRWLargeAttention(torch.nn.Module): else: # Add present to the layer_past tensor at the correct indices layer_past[layer_past_present_indices] = kv - k, v = layer_past.split(1, dim=2) # Expand to query shape - k = k.expand(-1, self.num_groups, self.num_heads, self.head_size).reshape( - -1, self.num_groups * self.num_heads, self.head_size - ) - v = v.expand(-1, self.num_groups, self.num_heads, self.head_size).reshape( - -1, self.num_groups * self.num_heads, self.head_size - ) + kv = layer_past.unsqueeze(2).expand(-1, self.num_groups, self.num_heads, 2, self.head_size).reshape(-1, self.num_groups * self.num_heads, 2, self.head_size) # output attn_output = torch.empty_like(query) # flash attention flash_attn_cuda.fwd( query, - k, - v, + kv[:, :, 0], + kv[:, :, 1], attn_output, cu_seqlens_q, cu_seqlens,