diff --git a/server/text_generation_server/layers/attention/ipex.py b/server/text_generation_server/layers/attention/ipex.py
index 131c9bb0..17f6a7f1 100644
--- a/server/text_generation_server/layers/attention/ipex.py
+++ b/server/text_generation_server/layers/attention/ipex.py
@@ -66,6 +66,7 @@ def paged_attention(
     softcap: Optional[float] = None,
 ):
     out = torch.empty_like(query)
+    input_lengths = seqlen.input_lengths + seqlen.cache_lengths
     ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
         out,
         query,
@@ -74,7 +75,7 @@ def paged_attention(
         kv_head_mapping,
         softmax_scale,
         block_tables,
-        seqlen.input_lengths,
+        input_lengths,
         BLOCK_SIZE,
         max_s,
         None,
diff --git a/server/text_generation_server/layers/attention/rocm.py b/server/text_generation_server/layers/attention/rocm.py
index 01d4685a..27e7638a 100644
--- a/server/text_generation_server/layers/attention/rocm.py
+++ b/server/text_generation_server/layers/attention/rocm.py
@@ -104,7 +104,7 @@ def paged_attention(
         _PARTITION_SIZE = _PARTITION_SIZE_CUSTOM
 
     max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
-    input_lengths = seqlen.input_lengths
+    input_lengths = seqlen.input_lengths + seqlen.cache_lengths
 
     out = torch.empty_like(query)
 
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 7acc723a..1e0e9176 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -1917,7 +1917,7 @@ class FlashCausalLM(Model):
         batch.speculative_ids = speculative_ids
         batch.position_ids = next_position_ids + accepted_ids
         batch.cache_lengths_tensor += batch.input_lengths_tensor
-        batch.input_lengths_tensor = accepted_ids
+        batch.input_lengths_tensor = accepted_ids.to(dtype=torch.int32)
         batch.slot_indices += accepted_ids
         batch.adapter_meta.adapter_indices = next_adapter_indices
 
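For context on the two changes above, here is a minimal sketch (not code from the repository; the `Seqlen` class and the tensors below are simplified stand-ins for the real objects). It illustrates why the paged-attention kernels should receive the total per-sequence context length (`input_lengths + cache_lengths`) rather than only the tokens fed in the current step, and why assigning `accepted_ids` back to `input_lengths_tensor` needs an explicit int32 cast.

```python
# Minimal sketch, assuming simplified stand-in objects; not TGI code.
from dataclasses import dataclass

import torch


@dataclass
class Seqlen:
    # Tokens fed in this step (1 per sequence during regular decode).
    input_lengths: torch.Tensor
    # Tokens already sitting in the paged KV cache for each sequence.
    cache_lengths: torch.Tensor


seqlen = Seqlen(
    input_lengths=torch.tensor([1, 1, 1], dtype=torch.int32),
    cache_lengths=torch.tensor([17, 4, 92], dtype=torch.int32),
)

# Paged attention needs the *total* context length per sequence, not just the
# new tokens, hence `input_lengths + cache_lengths` in the diff.
total_lengths = seqlen.input_lengths + seqlen.cache_lengths
print(total_lengths)  # tensor([18,  5, 93], dtype=torch.int32)

# Speculative decoding: counting accepted tokens with an integer reduction
# yields int64, so assigning it back without a cast would silently change the
# dtype of the length tensor between decode steps.
accepted_mask = torch.tensor([[1, 1, 0], [1, 0, 0], [1, 1, 1]], dtype=torch.int32)
accepted_ids = accepted_mask.sum(dim=1)            # dtype: torch.int64
input_lengths_tensor = accepted_ids.to(dtype=torch.int32)
print(input_lengths_tensor.dtype)                  # torch.int32
```

The cast matters because PyTorch promotes integer reductions such as `sum` to int64 by default, so without `.to(dtype=torch.int32)` the length tensor handed to the attention kernels would flip dtype after a speculative step.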