Fixing falcon.

Nicolas Patry 2024-05-29 18:34:34 +00:00
parent a76e650283
commit cf59593454


@@ -198,9 +198,7 @@ class FlashRWAttention(torch.nn.Module):
         # Inplace rotary
         self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
 
-        paged_attention.reshape_and_cache(
-            kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
-        )
+        reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
 
         # output
         attn_output = torch.empty_like(query)
@@ -208,7 +206,7 @@ class FlashRWAttention(torch.nn.Module):
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            flash_attn.attention(
+            attention(
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
@@ -219,7 +217,7 @@ class FlashRWAttention(torch.nn.Module):
             )
         # Decode
         else:
-            paged_attention.attention(
+            paged_attention(
                 attn_output,
                 query,
                 kv_cache[0],
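
For context (not part of this diff): the switch from `paged_attention.reshape_and_cache(...)`, `flash_attn.attention(...)`, and `paged_attention.attention(...)` to the bare `reshape_and_cache`, `attention`, and `paged_attention` calls suggests the module now imports these helpers as free functions from the refactored attention layer rather than going through the old submodules. A minimal sketch of what the corresponding import block might look like, assuming the `text_generation_server.layers.attention` module path (an assumption, as the import hunk is not shown here):

# Assumed import after the attention-layer refactor; not visible in this diff.
from text_generation_server.layers.attention import (
    attention,
    paged_attention,
    reshape_and_cache,
)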