import intel_extension_for_pytorch as ipex
import torch

from text_generation_server.layers.attention.kv_cache import KVCache, KVScales
from text_generation_server.layers.attention import Seqlen
from typing import Optional
from text_generation_server.models.globals import (
    ATTENTION,
    BLOCK_SIZE,
)
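
# IPEX (intel_extension_for_pytorch) attention backend. attention() handles the
# variable-length (prefill) pass and paged_attention() the single-query (decode)
# pass; with ATTENTION == "flashdecoding-ipex" both route through
# PagedAttention.flash_attn_varlen_func, otherwise they fall back to the
# classic IPEX kernels.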
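
# Sliding-window attention is only supported by the flashdecoding-ipex kernels.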
if ATTENTION == "flashdecoding-ipex":
    SUPPORTS_WINDOWING = True
else:
    SUPPORTS_WINDOWING = False


def attention(
    *,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: KVCache,
    kv_scales: KVScales,
    seqlen: Seqlen,
    block_tables: torch.Tensor,
    softmax_scale: float,
    window_size_left: int = -1,
    causal: bool = True,
    softcap: Optional[float] = None,
):
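    """Attention over variable-length (prefill) sequences.

    With ATTENTION == "flashdecoding-ipex", queries attend against the paged
    KV cache through flash_attn_varlen_func; otherwise the freshly computed
    key/value tensors are used directly via varlen_attention.
    """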
    out = torch.empty_like(query)

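    # Tell the IPEX kernel which fp8 format (if any) the KV cache is stored in.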
    kv_cache_dtype = "auto"
    if kv_cache.key.dtype == torch.float8_e5m2:
        kv_cache_dtype = "fp8_e5m2"
    if kv_cache.key.dtype == torch.float8_e4m3fn:
        kv_cache_dtype = "fp8_e4m3"

    # window_size_left does not need to be validated here: unsupported
    # windowing configurations are rejected ahead of time at model load.
    if ATTENTION == "flashdecoding-ipex":
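        # -1 means "no window"; when a left window is set, the right window
        # is 0 so attention stays causal.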
        window_size_right = -1 if window_size_left == -1 else 0
        ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
            out,
            query.contiguous() if query.device.type == "xpu" else query,
            kv_cache.key,
            kv_cache.value,
            seqlen.cu_seqlen_q,
            seqlen.cu_seqlen_k,
            seqlen.max_q,
            seqlen.max_k,
            softmax_scale,
            causal,
            block_tables,
            None,
            window_size_left=window_size_left,
            window_size_right=window_size_right,
            kv_cache_dtype=kv_cache_dtype,
            k_scale=kv_scales.key_scale_cpu,
            v_scale=kv_scales.value_scale_cpu,
            softcap=softcap,
        )
    else:
        if softcap is not None:
            raise NotImplementedError(
                "softcap is not available in IPEX paged attention"
            )
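        # Fallback prefill path: run varlen attention directly over the
        # current key/value tensors instead of reading the paged KV cache.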
        ipex.llm.functional.varlen_attention(
            query.contiguous() if query.device.type == "xpu" else query,
            key.contiguous() if key.device.type == "xpu" else key,
            value.contiguous() if value.device.type == "xpu" else value,
            out,
            seqlen.cu_seqlen_q,
            seqlen.cu_seqlen_q,
            seqlen.max_q,
            seqlen.max_q,
            0.0,  # dropout probability
            softmax_scale,
            False,
            causal,
            False,
            None,
        )

    return out


def paged_attention(
    query: torch.Tensor,
    kv_cache: KVCache,
    kv_head_mapping: torch.Tensor,
    softmax_scale: float,
    block_tables: torch.Tensor,
    seqlen: Seqlen,
    max_s: int,
    *,
    kv_scales: KVScales,
    softcap: Optional[float] = None,
    window_size_left: Optional[int] = -1,
):
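    """Decode-time attention: each query token attends to its paged KV cache."""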
    out = torch.empty_like(query)

    kv_cache_dtype = "auto"
    if kv_cache.key.dtype == torch.float8_e5m2:
        kv_cache_dtype = "fp8_e5m2"
    if kv_cache.key.dtype == torch.float8_e4m3fn:
        kv_cache_dtype = "fp8_e4m3"

    if ATTENTION == "flashdecoding-ipex":
        window_size_right = -1 if window_size_left == -1 else 0
        ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
            out,
            query.contiguous() if query.device.type == "xpu" else query,
            kv_cache.key,
            kv_cache.value,
            seqlen.cu_seqlen_q,
            seqlen.cu_seqlen_k,
            seqlen.max_q,
            seqlen.max_k,
            softmax_scale,
            True,  # causal
            block_tables,
            None,
            window_size_left=window_size_left,
            window_size_right=window_size_right,
            kv_cache_dtype=kv_cache_dtype,
            k_scale=kv_scales.key_scale_cpu,
            v_scale=kv_scales.value_scale_cpu,
            softcap=softcap,
        )
    else:
        # Total tokens visible per sequence: this step's input plus the cache.
        input_lengths = seqlen.input_lengths + seqlen.cache_lengths
        if softcap is not None:
            raise NotImplementedError(
                "softcap is not available in IPEX paged attention"
            )
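        # Fallback decode path: each single query token attends to its cached
        # keys/values, gathered via the block tables.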
        ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
            out,
            query,
            kv_cache.key,
            kv_cache.value,
            kv_head_mapping,
            softmax_scale,
            block_tables,
            input_lengths,
            BLOCK_SIZE,
            max_s,
            None,
        )

    return out


__all__ = [
    "SUPPORTS_WINDOWING",
    "attention",
    "paged_attention",
]