From 66081e6ae71712ae6fe3b816a39349d5e045eafa Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Fri, 31 May 2024 21:41:19 +0000
Subject: [PATCH] Making it work on non flash decoding.

---
 .../models/custom_modeling/flash_llama_modeling.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index f33b1622..b522aa07 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -194,10 +194,10 @@ class FlashLlamaAttention(torch.nn.Module):
         reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
 
         # output tensor
+        attn_output = torch.empty_like(query)
 
         # Prefill
         if cu_seqlen_prefill is not None:
-            attn_output = torch.empty_like(query)
             # flash attention
             attention(
                 query,
@@ -211,7 +211,7 @@ class FlashLlamaAttention(torch.nn.Module):
         # Decode
         else:
             attn_output = paged_attention(
-                None,
+                attn_output,
                 query,
                 kv_cache[0],
                 kv_cache[1],
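
Reviewer note, not part of the patch: below is a minimal, self-contained sketch of the call pattern this change restores. The output buffer is allocated once, before the prefill/decode branch, so the decode path can hand paged attention a real tensor to write into instead of None. The helpers flash_prefill_attention and paged_decode_attention are hypothetical stand-ins for illustration only; they are not TGI's actual attention/paged_attention signatures.

import torch
import torch.nn.functional as F


def flash_prefill_attention(query, key, value, out):
    # Stand-in for the flash-attention prefill kernel: writes into `out` in place.
    out.copy_(F.scaled_dot_product_attention(query, key, value))


def paged_decode_attention(out, query, key_cache, value_cache):
    # Stand-in for paged attention: fills the preallocated `out` and returns it.
    out.copy_(F.scaled_dot_product_attention(query, key_cache, value_cache))
    return out


def attention_forward(query, key, value, key_cache, value_cache, cu_seqlen_prefill):
    # Output tensor is allocated unconditionally, mirroring the moved line in the patch.
    attn_output = torch.empty_like(query)

    # Prefill
    if cu_seqlen_prefill is not None:
        flash_prefill_attention(query, key, value, attn_output)
    # Decode
    else:
        # The decode path now receives the real buffer instead of None.
        attn_output = paged_decode_attention(attn_output, query, key_cache, value_cache)

    return attn_output


if __name__ == "__main__":
    q = torch.randn(1, 8, 4, 64)
    k = torch.randn(1, 8, 4, 64)
    v = torch.randn(1, 8, 4, 64)
    # Decode step: cu_seqlen_prefill is None, so the paged path is taken.
    out = attention_forward(q, k, v, k, v, cu_seqlen_prefill=None)
    print(out.shape)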