add contiguous

This commit is contained in:
OlivierDehaene 2024-04-04 18:48:58 +02:00
parent 5088005908
commit 4a02d3505f

View File

@ -222,8 +222,8 @@ class FlashCohereAttention(torch.nn.Module):
dim=1,
)
if self.use_qk_norm:
query = self.q_norm(query)
key = self.k_norm(key)
query = self.q_norm(query.contiguous())
key = self.k_norm(key.contiguous())
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_key_value_heads, self.head_size)