diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index f33b1622..b522aa07 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -194,10 +194,10 @@ class FlashLlamaAttention(torch.nn.Module):
         reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)

         # output tensor
+        attn_output = torch.empty_like(query)

         # Prefill
         if cu_seqlen_prefill is not None:
-            attn_output = torch.empty_like(query)
             # flash attention
             attention(
                 query,
@@ -211,7 +211,7 @@ class FlashLlamaAttention(torch.nn.Module):
         # Decode
         else:
             attn_output = paged_attention(
-                None,
+                attn_output,
                 query,
                 kv_cache[0],
                 kv_cache[1],
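
The diff hoists the allocation of the output tensor above the prefill/decode branch so that both attention paths write into the same preallocated buffer, and `paged_attention` now receives that buffer instead of `None`. A minimal sketch of the pattern under these assumptions; `run_flash_attention` and `run_paged_attention` are hypothetical stand-ins for the real `attention` / `paged_attention` kernels:

```python
import torch


def attention_forward_sketch(query, kv, kv_cache, slots, cu_seqlen_prefill,
                             run_flash_attention, run_paged_attention):
    """Sketch of the out-parameter pattern in the diff: allocate the output
    once, before branching, so prefill and decode share the same buffer.
    The kernel callables here are hypothetical stand-ins, not the real API."""
    # Allocate the output tensor up front, shaped like the query.
    attn_output = torch.empty_like(query)

    if cu_seqlen_prefill is not None:
        # Prefill: the flash-attention kernel writes into attn_output.
        run_flash_attention(query, kv, attn_output, cu_seqlen_prefill)
    else:
        # Decode: the paged-attention kernel also receives the preallocated
        # buffer (instead of None), so it no longer allocates its own output.
        run_paged_attention(attn_output, query, kv_cache[0], kv_cache[1], slots)

    return attn_output
```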