mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 04:44:52 +00:00
Fixing dtype + AMD, Ipex targets.
This commit is contained in:
parent
4fa4da3cb6
commit
fa491e730b
@ -66,6 +66,7 @@ def paged_attention(
|
|||||||
softcap: Optional[float] = None,
|
softcap: Optional[float] = None,
|
||||||
):
|
):
|
||||||
out = torch.empty_like(query)
|
out = torch.empty_like(query)
|
||||||
|
input_lengths = seqlen.input_lengths + seqlen.cache_lengths
|
||||||
ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
|
ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
|
||||||
out,
|
out,
|
||||||
query,
|
query,
|
||||||
@ -74,7 +75,7 @@ def paged_attention(
|
|||||||
kv_head_mapping,
|
kv_head_mapping,
|
||||||
softmax_scale,
|
softmax_scale,
|
||||||
block_tables,
|
block_tables,
|
||||||
seqlen.input_lengths,
|
input_lengths,
|
||||||
BLOCK_SIZE,
|
BLOCK_SIZE,
|
||||||
max_s,
|
max_s,
|
||||||
None,
|
None,
|
||||||
|
@ -104,7 +104,7 @@ def paged_attention(
|
|||||||
_PARTITION_SIZE = _PARTITION_SIZE_CUSTOM
|
_PARTITION_SIZE = _PARTITION_SIZE_CUSTOM
|
||||||
|
|
||||||
max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
|
max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
|
||||||
input_lengths = seqlen.input_lengths
|
input_lengths = seqlen.input_lengths + seqlen.cache_lengths
|
||||||
|
|
||||||
out = torch.empty_like(query)
|
out = torch.empty_like(query)
|
||||||
|
|
||||||
|
@ -1917,7 +1917,7 @@ class FlashCausalLM(Model):
|
|||||||
batch.speculative_ids = speculative_ids
|
batch.speculative_ids = speculative_ids
|
||||||
batch.position_ids = next_position_ids + accepted_ids
|
batch.position_ids = next_position_ids + accepted_ids
|
||||||
batch.cache_lengths_tensor += batch.input_lengths_tensor
|
batch.cache_lengths_tensor += batch.input_lengths_tensor
|
||||||
batch.input_lengths_tensor = accepted_ids
|
batch.input_lengths_tensor = accepted_ids.to(dtype=torch.int32)
|
||||||
batch.slot_indices += accepted_ids
|
batch.slot_indices += accepted_ids
|
||||||
batch.adapter_meta.adapter_indices = next_adapter_indices
|
batch.adapter_meta.adapter_indices = next_adapter_indices
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user