Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-12 04:44:52 +00:00)
Making it work on non-flash decoding.
commit 66081e6ae7
parent 4293a12863
@@ -194,10 +194,10 @@ class FlashLlamaAttention(torch.nn.Module):
         reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
 
         # output tensor
+        attn_output = torch.empty_like(query)
 
         # Prefill
         if cu_seqlen_prefill is not None:
-            attn_output = torch.empty_like(query)
             # flash attention
             attention(
                 query,
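The first hunk hoists `attn_output = torch.empty_like(query)` out of the prefill-only branch, so the output buffer exists before the code decides between the flash prefill path and the paged decode path. The snippet below is a minimal, self-contained sketch of that caller-owned-buffer idiom; `run_prefill`, `run_decode`, and `forward` are hypothetical stand-ins for TGI's `attention`, `paged_attention`, and `FlashLlamaAttention.forward`, not their real signatures.

import torch


def run_prefill(query, out):
    # Hypothetical flash-attention stand-in: writes into the caller's buffer.
    out.copy_(torch.softmax(query, dim=-1))
    return out


def run_decode(out, query):
    # Hypothetical paged-attention stand-in: uses the caller's buffer when one
    # is supplied, otherwise allocates its own (what passing None amounted to).
    if out is None:
        out = torch.empty_like(query)
    out.copy_(torch.softmax(query, dim=-1))
    return out


def forward(query, is_prefill):
    # After the change: the output buffer is created once, before the branch,
    # so both the prefill and decode paths can write into the same tensor.
    attn_output = torch.empty_like(query)
    if is_prefill:
        run_prefill(query, attn_output)
    else:
        attn_output = run_decode(attn_output, query)
    return attn_output


print(forward(torch.randn(2, 4), is_prefill=True).shape)   # torch.Size([2, 4])
print(forward(torch.randn(2, 4), is_prefill=False).shape)  # torch.Size([2, 4])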
@@ -211,7 +211,7 @@ class FlashLlamaAttention(torch.nn.Module):
         # Decode
         else:
             attn_output = paged_attention(
-                None,
+                attn_output,
                 query,
                 kv_cache[0],
                 kv_cache[1],
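The second hunk is the complementary half of the same change: the decode call now receives that preallocated buffer instead of `None`, so a paged-attention backend that expects to write into a caller-provided output tensor (presumably the non-flash decoding path this commit targets) has somewhere to write. The result is still assigned back to `attn_output`, which keeps the call compatible with backends that ignore the argument and return a freshly allocated tensor.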