Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 20:34:54 +00:00)
Hack to make other models work.
commit 6aeb5a73a1
parent 6bbc843097
@@ -453,6 +453,7 @@ class DbrxAttention(torch.nn.Module):
             self.kv_head_mapping,
             self.softmax_scale,
             block_tables,
+            None,
             input_lengths,
             max_s,
         )
@@ -251,6 +251,7 @@ class FlashGemmaAttention(torch.nn.Module):
             self.kv_head_mapping,
             self.softmax_scale,
             block_tables,
+            None,
             input_lengths,
             max_s,
         )
@@ -216,6 +216,7 @@ class MistralAttention(torch.nn.Module):
             self.kv_head_mapping,
             self.softmax_scale,
             block_tables,
+            None,
             input_lengths,
             max_s,
         )
@@ -295,6 +295,7 @@ class MixtralAttention(torch.nn.Module):
             self.kv_head_mapping,
             self.softmax_scale,
             block_tables,
+            None,
             input_lengths,
             max_s,
         )
@@ -172,6 +172,7 @@ class Qwen2Attention(torch.nn.Module):
             self.kv_head_mapping,
             self.softmax_scale,
             block_tables,
+            None,
             input_lengths,
             max_s,
         )
@@ -223,6 +223,7 @@ class FlashRWAttention(torch.nn.Module):
             self.kv_head_mapping,
             self.softmax_scale,
             block_tables,
+            None,
             input_lengths,
             max_s,
         )
@@ -346,6 +347,7 @@ class FlashRWLargeAttention(torch.nn.Module):
             self.kv_head_mapping,
             self.softmax_scale,
             block_tables,
+            None,
             input_lengths,
             max_s,
         )
@@ -305,6 +305,7 @@ class FlashMQAttention(torch.nn.Module):
             self.kv_head_mapping,
             self.softmax_scale,
             block_tables,
+            None,
             input_lengths,
             max_s,
         )
@@ -259,6 +259,7 @@ class Starcoder2Attention(torch.nn.Module):
             self.kv_head_mapping,
             self.softmax_scale,
             block_tables,
+            None,
             input_lengths,
             max_s,
         )
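Every model hunk above makes the same one-line change: an extra positional argument, passed as None, is inserted between block_tables and input_lengths at the attention call site. The self-contained sketch below illustrates the compatibility pattern this implies, assuming the shared helper gained a new parameter in that slot; the helper name paged_attention_stub, the parameter name cu_seqlen_q, and the branch logic are illustrative assumptions, not code from the repository.

import torch


def paged_attention_stub(
    kv_head_mapping: torch.Tensor,
    softmax_scale: float,
    block_tables: torch.Tensor,
    cu_seqlen_q,               # hypothetical new parameter occupying the inserted slot
    input_lengths: torch.Tensor,
    max_s: int,
) -> str:
    # Callers with no value for the new argument pass None explicitly,
    # which is exactly what each model hunk in this commit does.
    if cu_seqlen_q is None:
        return "decode path: using per-sequence input_lengths"
    return "prefill path: using cumulative query lengths"


if __name__ == "__main__":
    kv_head_mapping = torch.zeros(8, dtype=torch.int32)
    block_tables = torch.zeros((2, 4), dtype=torch.int32)
    input_lengths = torch.tensor([5, 7], dtype=torch.int32)
    # Mirrors the call shape added by this commit: None fills the new slot.
    print(paged_attention_stub(kv_head_mapping, 0.125, block_tables, None, input_lengths, 9))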
@@ -71,6 +71,7 @@ def attention(
     block_size = value_cache.shape[3]
     num_seqs, num_heads, head_size = query.shape
     max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
+    input_lengths = cu_seqlen_k
     if SYSTEM == "xpu":
         query = query.contiguous()
         return ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
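The final hunk touches the backend attention() wrapper rather than a model: it adds a line that aliases the incoming cu_seqlen_k to the local name input_lengths, so the rest of the body (partition sizing and the IPEX single-query cached-KV kernel call on xpu) keeps working unchanged. A minimal standalone sketch of that aliasing pattern follows; the signature, the _PARTITION_SIZE value, and the placeholder return are assumptions, and the real IPEX kernel call is omitted because its exact signature is not shown here.

import torch

_PARTITION_SIZE = 512  # illustrative value; only its use in the ceiling division matters here


def attention_sketch(
    query: torch.Tensor,        # [num_seqs, num_heads, head_size]
    value_cache: torch.Tensor,  # [num_blocks, num_heads, head_size, block_size]
    cu_seqlen_k: torch.Tensor,  # renamed parameter; the existing body expects `input_lengths`
    max_s: int,
):
    block_size = value_cache.shape[3]
    num_seqs, num_heads, head_size = query.shape
    max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE

    # The one-line hack: alias the new parameter name back to the name the
    # existing code paths use, instead of renaming every downstream reference.
    input_lengths = cu_seqlen_k
    assert input_lengths.shape[0] == num_seqs

    # Placeholder for the real backend dispatch (e.g. the xpu paged-attention
    # kernel); returning the derived sizes keeps the sketch self-contained.
    return block_size, num_heads, head_size, max_num_partitions


if __name__ == "__main__":
    q = torch.randn(2, 8, 64)
    v_cache = torch.randn(16, 8, 64, 16)
    lengths = torch.tensor([5, 7], dtype=torch.int32)
    print(attention_sketch(q, v_cache, lengths, max_s=9))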