diff --git a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py
index 56bfb9d0..90b8f7a4 100644
--- a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py
@@ -453,6 +453,7 @@ class DbrxAttention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
+                None,
                 input_lengths,
                 max_s,
             )
diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py
index cff4b5d5..994602f2 100644
--- a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py
@@ -251,6 +251,7 @@ class FlashGemmaAttention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
+                None,
                 input_lengths,
                 max_s,
             )
diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
index 65043dee..934e8551 100644
--- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
@@ -216,6 +216,7 @@ class MistralAttention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
+                None,
                 input_lengths,
                 max_s,
             )
diff --git a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
index be2d6c45..47ab3ccc 100644
--- a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
@@ -295,6 +295,7 @@ class MixtralAttention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
+                None,
                 input_lengths,
                 max_s,
             )
diff --git a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
index 3a6d2db5..fac6d3ad 100644
--- a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
@@ -172,6 +172,7 @@ class Qwen2Attention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
+                None,
                 input_lengths,
                 max_s,
             )
diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
index fa463a19..5c21df56 100644
--- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
@@ -223,6 +223,7 @@ class FlashRWAttention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
+                None,
                 input_lengths,
                 max_s,
             )
@@ -346,6 +347,7 @@ class FlashRWLargeAttention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
+                None,
                 input_lengths,
                 max_s,
             )
diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
index cfa4243f..8777156a 100644
--- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
@@ -305,6 +305,7 @@ class FlashMQAttention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
+                None,
                 input_lengths,
                 max_s,
             )
diff --git a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py
index 3e2ce4f9..7dd41ef9 100644
--- a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py
@@ -259,6 +259,7 @@ class Starcoder2Attention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
+                None,
                 input_lengths,
                 max_s,
             )
diff --git a/server/text_generation_server/utils/paged_attention.py b/server/text_generation_server/utils/paged_attention.py
index e9dd1249..b96ee813 100644
--- a/server/text_generation_server/utils/paged_attention.py
+++ b/server/text_generation_server/utils/paged_attention.py
@@ -71,6 +71,7 @@ def attention(
     block_size = value_cache.shape[3]
     num_seqs, num_heads, head_size = query.shape
     max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
+    input_lengths = cu_seqlen_k
     if SYSTEM == "xpu":
         query = query.contiguous()
         return ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
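
Note on the interface change (a hedged reading, not shown in full by this patch): every caller now passes one extra positional argument (None) between block_tables and input_lengths, and paged_attention.attention aliases input_lengths = cu_seqlen_k inside its body. Taken together this suggests the function's parameters were renamed to cu_seqlen_q / cu_seqlen_k, with the extra None filling the cu_seqlen_q slot on the decode path. A minimal sketch of what the updated signature likely looks like; only the cu_seqlen_k name is confirmed by the visible assignment, the rest is an assumption:

    import torch

    def attention(
        out: torch.Tensor,
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        kv_head_mapping: torch.Tensor,
        softmax_scale: float,
        block_tables: torch.Tensor,
        cu_seqlen_q,      # assumed new parameter; the modeling files above pass None here
        cu_seqlen_k,      # what the callers previously passed as input_lengths
        max_s: int,
    ):
        # The added line keeps the rest of the CUDA/ROCm code path unchanged:
        # downstream code continues to refer to input_lengths.
        input_lengths = cu_seqlen_k
        ...

Under this reading, no behavior changes for the existing decode path; the renamed parameters simply make room for backends (such as the XPU/ipex branch touched in the same function) that want cumulative sequence-length tensors rather than per-sequence lengths.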