Making it work on non-flash decoding.

Nicolas Patry 2024-05-31 21:41:19 +00:00
parent 4293a12863
commit 66081e6ae7


@@ -194,10 +194,10 @@ class FlashLlamaAttention(torch.nn.Module):
         reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
+        # output tensor
+        attn_output = torch.empty_like(query)
         # Prefill
         if cu_seqlen_prefill is not None:
-            attn_output = torch.empty_like(query)
             # flash attention
             attention(
                 query,
@@ -211,7 +211,7 @@ class FlashLlamaAttention(torch.nn.Module):
         # Decode
         else:
             attn_output = paged_attention(
-                None,
+                attn_output,
                 query,
                 kv_cache[0],
                 kv_cache[1],
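
In short, the change hoists the attn_output allocation out of the prefill branch and passes that buffer to paged_attention instead of None, presumably so the non-flash decode kernel has a pre-allocated tensor to write into. Below is a minimal sketch of that pattern; prefill_attention, paged_attention, and attention_forward here are simplified, hypothetical stand-ins for the real kernels and the FlashLlamaAttention.forward body, which take more arguments and use a real paged KV cache.

import torch

def prefill_attention(query, key, value, out):
    # Stand-in for the flash-attention prefill kernel; writes into the
    # caller-provided buffer instead of allocating its own output.
    scores = query @ key.transpose(-2, -1) / query.size(-1) ** 0.5
    out.copy_(torch.softmax(scores, dim=-1) @ value)

def paged_attention(out, query, key_cache, value_cache):
    # Stand-in for the paged-attention decode kernel; after this change it
    # receives the pre-allocated buffer rather than None.
    scores = query @ key_cache.transpose(-2, -1) / query.size(-1) ** 0.5
    out.copy_(torch.softmax(scores, dim=-1) @ value_cache)
    return out

def attention_forward(query, key, value, key_cache, value_cache, cu_seqlen_prefill):
    # output tensor, allocated once and shared by both branches
    attn_output = torch.empty_like(query)
    # Prefill
    if cu_seqlen_prefill is not None:
        prefill_attention(query, key, value, attn_output)
    # Decode
    else:
        attn_output = paged_attention(attn_output, query, key_cache, value_cache)
    return attn_output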