Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 20:34:54 +00:00)
Hack to make other models work.
commit 6aeb5a73a1
parent 6bbc843097
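Every model hunk below makes the same one-line change: the decode-path call into the shared attention helper gains an extra positional argument, a literal None, inserted between block_tables and input_lengths. Judging from the last hunk (input_lengths = cu_seqlen_k inside def attention), the helper grew a query-side cu-seqlen slot that the patched models do not use yet. A minimal, self-contained sketch of the pattern follows; the names attention, cu_seqlen_q, and cu_seqlen_k are illustrative assumptions, and the real helper takes more arguments than shown here.

# Illustrative sketch only; not the repository's actual signature.
def attention(block_tables, cu_seqlen_q, cu_seqlen_k, max_s):
    # block_tables is unused in this sketch; it is kept to mirror the call shape.
    # The commit re-establishes the old variable name inside the helper,
    # so the pre-existing decode logic keeps working unchanged.
    input_lengths = cu_seqlen_k
    if cu_seqlen_q is None:
        # Path taken by every call site patched in this commit.
        return [min(length, max_s) for length in input_lengths]
    # A flash-decoding-style path would use cu_seqlen_q here (not shown in the diff).
    raise NotImplementedError

# Call site shaped like the hunks below; the inserted None is the whole change.
print(attention([[0, 1]], None, [5, 9], max_s=8))  # -> [5, 8]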
@@ -453,6 +453,7 @@ class DbrxAttention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
+                None,
                 input_lengths,
                 max_s,
             )
@@ -251,6 +251,7 @@ class FlashGemmaAttention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
+                None,
                 input_lengths,
                 max_s,
             )
@@ -216,6 +216,7 @@ class MistralAttention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
+                None,
                 input_lengths,
                 max_s,
             )
@@ -295,6 +295,7 @@ class MixtralAttention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
+                None,
                 input_lengths,
                 max_s,
             )
@@ -172,6 +172,7 @@ class Qwen2Attention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
+                None,
                 input_lengths,
                 max_s,
             )
@@ -223,6 +223,7 @@ class FlashRWAttention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
+                None,
                 input_lengths,
                 max_s,
             )
@@ -346,6 +347,7 @@ class FlashRWLargeAttention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
+                None,
                 input_lengths,
                 max_s,
             )
@@ -305,6 +305,7 @@ class FlashMQAttention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
+                None,
                 input_lengths,
                 max_s,
             )
@@ -259,6 +259,7 @@ class Starcoder2Attention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
+                None,
                 input_lengths,
                 max_s,
             )
@@ -71,6 +71,7 @@ def attention(
     block_size = value_cache.shape[3]
     num_seqs, num_heads, head_size = query.shape
     max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
+    input_lengths = cu_seqlen_k
     if SYSTEM == "xpu":
         query = query.contiguous()
         return ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
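This last hunk is the receiving side of the change: the helper now reads a cu_seqlen_k name (presumably a renamed or added parameter, since the full signature is not shown in this diff), and the added assignment restores the input_lengths name that the rest of the function body, including the XPU/ipex branch visible in the context, still uses. The surrounding context also shows how the partition count for the paged-attention kernel is sized; a small sketch of that ceiling division, assuming _PARTITION_SIZE = 512 (the value commonly used by vLLM-style kernels; the actual constant is defined elsewhere in the file):

# Sketch of the max_num_partitions computation from the context lines above.
_PARTITION_SIZE = 512  # assumption; defined elsewhere in the real module

def max_num_partitions(max_s: int) -> int:
    # Integer ceiling division: ceil(max_s / _PARTITION_SIZE).
    return (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE

assert max_num_partitions(1) == 1
assert max_num_partitions(512) == 1
assert max_num_partitions(513) == 2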