
Hack to make other models work.
Nicolas Patry 2024-05-29 10:52:09 +00:00
parent 6bbc843097
commit 6aeb5a73a1
9 changed files with 10 additions and 0 deletions


@@ -453,6 +453,7 @@ class DbrxAttention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
+                None,
                 input_lengths,
                 max_s,
             )


@@ -251,6 +251,7 @@ class FlashGemmaAttention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
+                None,
                 input_lengths,
                 max_s,
             )


@@ -216,6 +216,7 @@ class MistralAttention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
+                None,
                 input_lengths,
                 max_s,
             )


@@ -295,6 +295,7 @@ class MixtralAttention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
+                None,
                 input_lengths,
                 max_s,
             )


@@ -172,6 +172,7 @@ class Qwen2Attention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
+                None,
                 input_lengths,
                 max_s,
             )


@@ -223,6 +223,7 @@ class FlashRWAttention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
+                None,
                 input_lengths,
                 max_s,
             )
@@ -346,6 +347,7 @@ class FlashRWLargeAttention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
+                None,
                 input_lengths,
                 max_s,
             )


@@ -305,6 +305,7 @@ class FlashMQAttention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
+                None,
                 input_lengths,
                 max_s,
             )


@@ -259,6 +259,7 @@ class Starcoder2Attention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
+                None,
                 input_lengths,
                 max_s,
             )


@@ -71,6 +71,7 @@ def attention(
     block_size = value_cache.shape[3]
     num_seqs, num_heads, head_size = query.shape
     max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
+    input_lengths = cu_seqlen_k
     if SYSTEM == "xpu":
         query = query.contiguous()
         return ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(