mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
Fixing falcon.
This commit is contained in:
parent
a76e650283
commit
cf59593454
@ -198,9 +198,7 @@ class FlashRWAttention(torch.nn.Module):
|
|||||||
# Inplace rotary
|
# Inplace rotary
|
||||||
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
|
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
|
||||||
|
|
||||||
paged_attention.reshape_and_cache(
|
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
|
||||||
kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
|
|
||||||
)
|
|
||||||
|
|
||||||
# output
|
# output
|
||||||
attn_output = torch.empty_like(query)
|
attn_output = torch.empty_like(query)
|
||||||
@ -208,7 +206,7 @@ class FlashRWAttention(torch.nn.Module):
|
|||||||
# Prefill
|
# Prefill
|
||||||
if cu_seqlen_prefill is not None:
|
if cu_seqlen_prefill is not None:
|
||||||
# flash attention
|
# flash attention
|
||||||
flash_attn.attention(
|
attention(
|
||||||
query,
|
query,
|
||||||
torch.select(kv, dim=1, index=0),
|
torch.select(kv, dim=1, index=0),
|
||||||
torch.select(kv, dim=1, index=1),
|
torch.select(kv, dim=1, index=1),
|
||||||
@ -219,7 +217,7 @@ class FlashRWAttention(torch.nn.Module):
|
|||||||
)
|
)
|
||||||
# Decode
|
# Decode
|
||||||
else:
|
else:
|
||||||
paged_attention.attention(
|
paged_attention(
|
||||||
attn_output,
|
attn_output,
|
||||||
query,
|
query,
|
||||||
kv_cache[0],
|
kv_cache[0],
|
||||||
|
Loading…
Reference in New Issue
Block a user