Fixing falcon.

Nicolas Patry 2024-05-29 18:34:34 +00:00
parent a76e650283
commit cf59593454


@@ -198,9 +198,7 @@ class FlashRWAttention(torch.nn.Module):
         # Inplace rotary
         self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
 
-        paged_attention.reshape_and_cache(
-            kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
-        )
+        reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
 
         # output
         attn_output = torch.empty_like(query)
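This first hunk only changes the call site, from the `paged_attention` module namespace to a directly imported `reshape_and_cache`; the cache write itself is untouched. As a rough mental model of that step (a flat-layout simplification I'm assuming here, not the library's fused CUDA kernel, which uses a blocked cache layout), it scatters each token's key/value vectors into preassigned cache slots:

import torch

def reshape_and_cache_sketch(
    key: torch.Tensor,          # [num_tokens, num_heads, head_size]
    value: torch.Tensor,        # [num_tokens, num_heads, head_size]
    key_cache: torch.Tensor,    # [num_slots, num_heads, head_size], simplified flat layout
    value_cache: torch.Tensor,  # [num_slots, num_heads, head_size]
    slots: torch.Tensor,        # [num_tokens], one flat cache slot index per token
) -> None:
    # In-place scatter: token i's K/V land in cache row slots[i].
    key_cache[slots] = key
    value_cache[slots] = value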
@@ -208,7 +206,7 @@ class FlashRWAttention(torch.nn.Module):
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            flash_attn.attention(
+            attention(
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
@@ -219,7 +217,7 @@ class FlashRWAttention(torch.nn.Module):
             )
         # Decode
         else:
-            paged_attention.attention(
+            paged_attention(
                 attn_output,
                 query,
                 kv_cache[0],
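Taken together, the three hunks drop the `flash_attn.` and `paged_attention.` module prefixes in favor of directly imported entry points, so prefill dispatches to `attention` and decode to `paged_attention`. The imports at the top of the file presumably become something along these lines (the exact module path is my assumption; it is not shown in this excerpt):

from text_generation_server.layers.attention import (
    attention,          # flash-attention kernel used on the prefill path
    paged_attention,    # paged-attention kernel used on the decode path
    reshape_and_cache,  # KV-cache scatter shown in the first hunk
)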