Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-09-11 20:34:54 +00:00
Fixing falcon.

commit cf59593454
parent a76e650283
@@ -198,9 +198,7 @@ class FlashRWAttention(torch.nn.Module):
         # Inplace rotary
         self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)

-        paged_attention.reshape_and_cache(
-            kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
-        )
+        reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)

         # output
         attn_output = torch.empty_like(query)
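For context: reshape_and_cache writes this step's new key/value vectors into the KV cache at the given slot indices. The real op is a CUDA kernel; the following is a minimal pure-PyTorch sketch of its semantics, where the shapes and the flat (non-blocked) cache layout are assumptions for illustration only.

import torch

def reshape_and_cache_sketch(key, value, key_cache, value_cache, slots):
    # Assumed shapes, for illustration only:
    #   key, value:             [num_tokens, num_kv_heads, head_size]
    #   key_cache, value_cache: [num_slots,  num_kv_heads, head_size]
    #   slots:                  [num_tokens] int64 slot index per token
    # Scatter each token's K/V vectors into their assigned cache slots.
    key_cache[slots] = key
    value_cache[slots] = value

# Mirrors the call site above: kv stacks K and V along dim=1, so
# kv[:, 0] is K and kv[:, 1] is V.
kv = torch.randn(4, 2, 1, 64)
key_cache, value_cache = torch.zeros(128, 1, 64), torch.zeros(128, 1, 64)
slots = torch.tensor([5, 6, 7, 8])
reshape_and_cache_sketch(kv[:, 0], kv[:, 1], key_cache, value_cache, slots)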
@@ -208,7 +206,7 @@ class FlashRWAttention(torch.nn.Module):
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            flash_attn.attention(
+            attention(
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
@@ -219,7 +217,7 @@ class FlashRWAttention(torch.nn.Module):
             )
         # Decode
         else:
-            paged_attention.attention(
+            paged_attention(
                 attn_output,
                 query,
                 kv_cache[0],
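Taken together, the hunks drop the module namespaces: flash_attn.attention and paged_attention.attention become the free functions attention and paged_attention, presumably imported at the top of the file. Below is a self-contained toy sketch of the prefill/decode dispatch this code path implements. Every function here is a simplified stand-in (single sequence, flat cache, plain PyTorch math), not the flash-attention or paged-KV kernels the real code calls.

import torch

def attention(query, key, value):
    # Prefill stand-in: plain causal attention over one sequence.
    # query/key/value: [seq_len, num_heads, head_size].
    scale = query.shape[-1] ** -0.5
    scores = torch.einsum("qhd,khd->hqk", query, key) * scale
    q_len, k_len = query.shape[0], key.shape[0]
    mask = torch.ones(q_len, k_len, dtype=torch.bool).tril(k_len - q_len)
    scores = scores.masked_fill(~mask, float("-inf"))
    return torch.einsum("hqk,khd->qhd", scores.softmax(dim=-1), value)

def paged_attention(attn_output, query, key_cache, value_cache, context_slots):
    # Decode stand-in: the single new query token attends over every slot
    # its sequence occupies. The real kernel locates those slots through
    # block tables and sequence lengths rather than an explicit index list.
    k, v = key_cache[context_slots], value_cache[context_slots]
    attn_output.copy_(attention(query, k, v))

# One sequence and a flat (non-blocked) cache -- both simplifications.
num_heads, head_size = 2, 16
key_cache = torch.zeros(64, num_heads, head_size)
value_cache = torch.zeros(64, num_heads, head_size)

# Prefill: 5 prompt tokens land in slots 0..4, then attend to each other.
kv = torch.randn(5, 2, num_heads, head_size)
key_cache[torch.arange(5)], value_cache[torch.arange(5)] = kv[:, 0], kv[:, 1]
out = attention(torch.randn(5, num_heads, head_size), kv[:, 0], kv[:, 1])

# Decode: token 5 goes into slot 5, then attends over slots 0..5.
new_kv = torch.randn(1, 2, num_heads, head_size)
key_cache[5], value_cache[5] = new_kv[0, 0], new_kv[0, 1]
query = torch.randn(1, num_heads, head_size)
attn_output = torch.empty_like(query)
paged_attention(attn_output, query, key_cache, value_cache, torch.arange(6))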