diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
index e0e423b5..50962a1a 100644
--- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
@@ -198,9 +198,7 @@ class FlashRWAttention(torch.nn.Module):
         # Inplace rotary
         self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
 
-        paged_attention.reshape_and_cache(
-            kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
-        )
+        reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
 
         # output
         attn_output = torch.empty_like(query)
@@ -208,7 +206,7 @@ class FlashRWAttention(torch.nn.Module):
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            flash_attn.attention(
+            attention(
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
@@ -219,7 +217,7 @@ class FlashRWAttention(torch.nn.Module):
             )
         # Decode
         else:
-            paged_attention.attention(
+            paged_attention(
                 attn_output,
                 query,
                 kv_cache[0],
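
Note: the bare reshape_and_cache, attention, and paged_attention calls introduced above rely on these helpers being imported as top-level functions rather than accessed through the flash_attn and paged_attention modules. A minimal sketch of the assumed import block follows; the exact module path is not shown in these hunks and is an assumption.

    # Sketch only: assumed consolidated attention-helper imports near the top
    # of flash_rw_modeling.py. The module path below is an assumption, not
    # part of this diff.
    from text_generation_server.layers.attention import (
        attention,
        paged_attention,
        reshape_and_cache,
    )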