Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 20:34:54 +00:00)
Making it work on non-flash decoding.

Parent: 4293a12863
Commit: 66081e6ae7
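The change below hinges on a calling convention: non-flash decode kernels write the attention output into a caller-provided buffer rather than allocating one themselves, so passing `None` as that buffer breaks them. Here is a minimal, self-contained sketch of that convention; `_paged_attention_sketch` and its body are hypothetical stand-ins for illustration, not TGI's kernel.

import torch

def _paged_attention_sketch(out, query):
    # Hypothetical stand-in: like a non-flash paged-attention kernel, it
    # fills a caller-provided buffer in place and returns it, so `out`
    # must be a real tensor rather than None.
    out.copy_(query)  # placeholder for the actual attention computation
    return out

query = torch.randn(4, 32, 128)        # [num_tokens, num_heads, head_dim]
attn_output = torch.empty_like(query)  # preallocated, as in this commit
attn_output = _paged_attention_sketch(attn_output, query)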
@@ -194,10 +194,10 @@ class FlashLlamaAttention(torch.nn.Module):
         reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
 
         # output tensor
+        attn_output = torch.empty_like(query)
 
         # Prefill
         if cu_seqlen_prefill is not None:
-            attn_output = torch.empty_like(query)
             # flash attention
             attention(
                 query,
@@ -211,7 +211,7 @@ class FlashLlamaAttention(torch.nn.Module):
         # Decode
         else:
             attn_output = paged_attention(
-                None,
+                attn_output,
                 query,
                 kv_cache[0],
                 kv_cache[1],
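Taken together, the two hunks leave the forward pass allocating the output tensor once and reusing it on both branches. Below is a hedged, runnable sketch of that prefill/decode dispatch; `naive_attention_into`, `forward_sketch`, the shapes, and the `cu_seqlen_prefill` handling are illustrative assumptions, not TGI's implementation.

import math
import torch

def naive_attention_into(out, q, k, v, softmax_scale):
    # Writes softmax(q @ k^T * scale) @ v into the caller-provided buffer,
    # mimicking kernels that take an output tensor instead of returning one.
    scores = torch.einsum("qhd,khd->hqk", q, k) * softmax_scale
    out.copy_(torch.einsum("hqk,khd->qhd", scores.softmax(dim=-1), v))
    return out

def forward_sketch(query, key, value, cu_seqlen_prefill):
    softmax_scale = 1.0 / math.sqrt(query.shape[-1])
    # Output tensor: allocated before the branch (first hunk), so the
    # decode path can hand the same buffer to paged attention (second hunk).
    attn_output = torch.empty_like(query)
    # Prefill
    if cu_seqlen_prefill is not None:
        # In TGI, flash attention fills attn_output in place here.
        naive_attention_into(attn_output, query, key, value, softmax_scale)
    # Decode
    else:
        # The preallocated buffer is passed in rather than None.
        attn_output = naive_attention_into(
            attn_output, query, key, value, softmax_scale
        )
    return attn_output

q = k = v = torch.randn(4, 8, 64)  # [num_tokens, num_heads, head_dim]
out = forward_sketch(q, k, v, cu_seqlen_prefill=torch.tensor([0, 4]))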