Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 20:34:54 +00:00)
Fixing dtype + AMD, Ipex targets.
This commit is contained in:
parent
4fa4da3cb6
commit
fa491e730b
@@ -66,6 +66,7 @@ def paged_attention(
     softcap: Optional[float] = None,
 ):
     out = torch.empty_like(query)
+    input_lengths = seqlen.input_lengths + seqlen.cache_lengths
     ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
         out,
         query,
@@ -74,7 +75,7 @@ def paged_attention(
         kv_head_mapping,
         softmax_scale,
         block_tables,
-        seqlen.input_lengths,
+        input_lengths,
         BLOCK_SIZE,
         max_s,
         None,
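The two hunks above change the IPEX paged-attention path so the per-sequence length handed to single_query_cached_kv_attention covers tokens already resident in the KV cache as well as the ones added in this step. A minimal sketch of that arithmetic, assuming a stand-in Seqlen dataclass (TGI's real Seqlen carries more fields than this):

```python
from dataclasses import dataclass

import torch


@dataclass
class Seqlen:
    input_lengths: torch.Tensor  # tokens added in this forward pass
    cache_lengths: torch.Tensor  # tokens already stored in the KV cache


seqlen = Seqlen(
    input_lengths=torch.tensor([1, 1], dtype=torch.int32),    # decode: 1 new token each
    cache_lengths=torch.tensor([17, 42], dtype=torch.int32),  # previously cached prefix
)

# Total number of KV entries each query must attend over.
input_lengths = seqlen.input_lengths + seqlen.cache_lengths
print(input_lengths)  # tensor([18, 43], dtype=torch.int32)
```

Passing only seqlen.input_lengths would make the kernel skip the cached prefix entirely, which is why the backend now uses the summed value.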
@@ -104,7 +104,7 @@ def paged_attention(
         _PARTITION_SIZE = _PARTITION_SIZE_CUSTOM

     max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
-    input_lengths = seqlen.input_lengths
+    input_lengths = seqlen.input_lengths + seqlen.cache_lengths

     out = torch.empty_like(query)
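This hunk applies the same length fix to the other paged_attention implementation (the AMD/ROCm backend, going by the commit title). The unchanged context line above it picks the number of attention partitions via integer ceiling division; a tiny self-contained check of that idiom, using an illustrative partition size rather than the backend's real constant:

```python
# Illustrative only: 512 is a placeholder, not the backend's actual
# _PARTITION_SIZE constant.
_PARTITION_SIZE = 512


def max_num_partitions(max_s: int) -> int:
    # (a + b - 1) // b is ceiling division: it rounds max_s / _PARTITION_SIZE
    # up, so a final partially filled partition is still scheduled.
    return (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE


assert max_num_partitions(512) == 1  # exactly one full partition
assert max_num_partitions(513) == 2  # one extra token forces a second partition
```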
@@ -1917,7 +1917,7 @@ class FlashCausalLM(Model):
             batch.speculative_ids = speculative_ids
             batch.position_ids = next_position_ids + accepted_ids
             batch.cache_lengths_tensor += batch.input_lengths_tensor
-            batch.input_lengths_tensor = accepted_ids
+            batch.input_lengths_tensor = accepted_ids.to(dtype=torch.int32)
             batch.slot_indices += accepted_ids
             batch.adapter_meta.adapter_indices = next_adapter_indices
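This last hunk is the dtype part of the commit title. accepted_ids comes out of speculative decoding as a default-dtype (int64) tensor, so plain assignment rebinds batch.input_lengths_tensor to an int64 tensor, while the length tensors consumed by the attention backends are presumably expected to stay int32, hence the explicit cast. A standalone reproduction of the silent dtype switch (not TGI code):

```python
import torch

# Length tensors are created as int32, the dtype the attention kernels consume.
input_lengths_tensor = torch.tensor([5, 7], dtype=torch.int32)

# accepted_ids from speculative decoding uses torch's default integer dtype.
accepted_ids = torch.tensor([2, 1])
print(accepted_ids.dtype)  # torch.int64

# Plain assignment rebinds the name to the int64 tensor...
input_lengths_tensor = accepted_ids
print(input_lengths_tensor.dtype)  # torch.int64

# ...while the commit's fix keeps the expected dtype.
input_lengths_tensor = accepted_ids.to(dtype=torch.int32)
print(input_lengths_tensor.dtype)  # torch.int32
```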