From f400f2d58be0cba094782658f5bf342b5dcb39af Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Tue, 18 Jul 2023 09:21:22 +0200
Subject: [PATCH] use native grouped attention

---
 .../models/custom_modeling/flash_rw_modeling.py       | 11 -----------
 .../custom_modeling/flash_santacoder_modeling.py      |  3 ---
 2 files changed, 14 deletions(-)

diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
index 96fa1b8a..eeac2b9e 100644
--- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
@@ -182,10 +182,6 @@ class FlashRWAttention(torch.nn.Module):
 
         # Prefill
         if cu_seqlen_prefill is not None:
-            if self.num_heads_kv == 1:
-                # Expand to query shape
-                kv = kv.expand(-1, 2, self.num_heads, self.head_size)
-
             # flash attention
             flash_attn_2_cuda.varlen_fwd(
                 query,
@@ -313,13 +309,6 @@ class FlashRWLargeAttention(torch.nn.Module):
 
         # Prefill
         if cu_seqlen_prefill is not None:
-            # Expand to query shape
-            kv = (
-                kv.unsqueeze(2)
-                .expand(-1, self.num_groups, self.num_heads, 2, self.head_size)
-                .reshape(-1, self.num_groups * self.num_heads, 2, self.head_size)
-            )
-
             # flash attention
             flash_attn_2_cuda.varlen_fwd(
                 query,
diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
index 165725c1..76dcf1d5 100644
--- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
@@ -271,9 +271,6 @@ class FlashMQAttention(torch.nn.Module):
 
         # Prefill
         if cu_seqlen_prefill is not None:
-            # Expand from 1 to num_heads
-            key_value = key_value.expand(-1, 2, self.num_heads, self.head_size)
-
             # flash attention
             flash_attn_2_cuda.varlen_fwd(
                 query,
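
Background sketch (not part of the patch): flash-attn v2 supports multi-query and grouped-query attention natively, i.e. it accepts key/value tensors with fewer heads than the query, which is why the manual expand-to-query-shape steps removed above are no longer needed. The snippet below is a minimal illustration using the public flash_attn_varlen_func wrapper rather than the lower-level flash_attn_2_cuda.varlen_fwd binding the patched code calls; the shapes, sequence lengths, and variable names are hypothetical.

    import torch
    from flash_attn import flash_attn_varlen_func  # public wrapper around the varlen flash attention kernel

    # Hypothetical grouped-query shapes: 32 query heads sharing 8 KV heads.
    num_heads, num_heads_kv, head_size = 32, 8, 64
    total_tokens = 128  # tokens of all sequences in the batch, concatenated
    cu_seqlens = torch.tensor([0, 64, 128], dtype=torch.int32, device="cuda")
    max_seqlen = 64

    q = torch.randn(total_tokens, num_heads, head_size, dtype=torch.float16, device="cuda")
    # K and V keep their native (smaller) head count: no .expand() to num_heads.
    k = torch.randn(total_tokens, num_heads_kv, head_size, dtype=torch.float16, device="cuda")
    v = torch.randn(total_tokens, num_heads_kv, head_size, dtype=torch.float16, device="cuda")

    out = flash_attn_varlen_func(
        q, k, v,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=max_seqlen,
        max_seqlen_k=max_seqlen,
        causal=True,
    )  # out: (total_tokens, num_heads, head_size)

The kernel broadcasts each KV head across its group of query heads internally, so the intermediate expanded tensors (and the extra memory traffic they imply during prefill) are avoided.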