Fixing dtype + AMD, Ipex targets.

This commit is contained in:
Nicolas Patry 2024-10-15 17:56:03 +02:00
parent 4fa4da3cb6
commit fa491e730b
No known key found for this signature in database
GPG Key ID: D2920555C90F704C
3 changed files with 4 additions and 3 deletions

View File

@@ -66,6 +66,7 @@ def paged_attention(
softcap: Optional[float] = None,
):
out = torch.empty_like(query)
+input_lengths = seqlen.input_lengths + seqlen.cache_lengths
ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
out,
query,
@@ -74,7 +75,7 @@ def paged_attention(
kv_head_mapping,
softmax_scale,
block_tables,
-seqlen.input_lengths,
+input_lengths,
BLOCK_SIZE,
max_s,
None,

View File

@@ -104,7 +104,7 @@ def paged_attention(
_PARTITION_SIZE = _PARTITION_SIZE_CUSTOM
max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
-input_lengths = seqlen.input_lengths
+input_lengths = seqlen.input_lengths + seqlen.cache_lengths
out = torch.empty_like(query)

View File

@@ -1917,7 +1917,7 @@ class FlashCausalLM(Model):
batch.speculative_ids = speculative_ids
batch.position_ids = next_position_ids + accepted_ids
batch.cache_lengths_tensor += batch.input_lengths_tensor
-batch.input_lengths_tensor = accepted_ids
+batch.input_lengths_tensor = accepted_ids.to(dtype=torch.int32)
batch.slot_indices += accepted_ids
batch.adapter_meta.adapter_indices = next_adapter_indices