text-generation-inference/server/text_generation_server/layers/attention/ipex.py
import intel_extension_for_pytorch as ipex
import torch
from text_generation_server.layers.attention.kv_cache import KVCache, KVScales
from text_generation_server.layers.attention import Seqlen
from typing import Optional
from text_generation_server.models.globals import (
    ATTENTION,
    BLOCK_SIZE,
)

if ATTENTION == "flashdecoding-ipex":
    SUPPORTS_WINDOWING = True
else:
    SUPPORTS_WINDOWING = False


def attention(
    *,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: KVCache,
    kv_scales: KVScales,
    seqlen: Seqlen,
    block_tables: torch.Tensor,
    softmax_scale: float,
    window_size_left: int = -1,
    causal: bool = True,
    softcap: Optional[float] = None,
):
    """Prefill (variable-length) attention, dispatched to IPEX kernels."""
    out = torch.empty_like(query)

    # Map fp8 KV-cache dtypes to the string identifiers expected by IPEX.
    kv_cache_dtype = "auto"
    if kv_cache.key.dtype == torch.float8_e5m2:
        kv_cache_dtype = "fp8_e5m2"
    if kv_cache.key.dtype == torch.float8_e4m3fn:
        kv_cache_dtype = "fp8_e4m3"

    # No need to validate window_size_left here: when windowing is not
    # supported, it is already rejected ahead of time, at model load.
    if ATTENTION == "flashdecoding-ipex":
        window_size_right = -1 if window_size_left == -1 else 0
        ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
            out,
            query.contiguous() if query.device.type == "xpu" else query,
            kv_cache.key,
            kv_cache.value,
            seqlen.cu_seqlen_q,
            seqlen.cu_seqlen_k,
            seqlen.max_q,
            seqlen.max_k,
            softmax_scale,
            causal,
            block_tables,
            None,
            window_size_left=window_size_left,
            window_size_right=window_size_right,
            kv_cache_dtype=kv_cache_dtype,
            k_scale=kv_scales.key_scale_cpu,
            v_scale=kv_scales.value_scale_cpu,
            softcap=softcap,
        )
    else:
        if softcap is not None:
            raise NotImplementedError(
                "softcap is not available in IPEX paged attention"
            )
        # Query and key cover the same tokens in this prefill path, so the
        # query cumulative sequence lengths are passed for both.
        ipex.llm.functional.varlen_attention(
            query.contiguous() if query.device.type == "xpu" else query,
            key.contiguous() if key.device.type == "xpu" else key,
            value.contiguous() if value.device.type == "xpu" else value,
            out,
            seqlen.cu_seqlen_q,
            seqlen.cu_seqlen_q,
            seqlen.max_q,
            seqlen.max_q,
            0.0,
            softmax_scale,
            False,
            causal,
            False,
            None,
        )

    return out


def paged_attention(
    query: torch.Tensor,
    kv_cache: KVCache,
    kv_head_mapping: torch.Tensor,
    softmax_scale: float,
    block_tables: torch.Tensor,
    seqlen: Seqlen,
    max_s: int,
    *,
    kv_scales: KVScales,
    softcap: Optional[float] = None,
    window_size_left: Optional[int] = -1,
):
    """Decode attention over the paged KV cache, dispatched to IPEX kernels."""
    out = torch.empty_like(query)

    # Map fp8 KV-cache dtypes to the string identifiers expected by IPEX.
    kv_cache_dtype = "auto"
    if kv_cache.key.dtype == torch.float8_e5m2:
        kv_cache_dtype = "fp8_e5m2"
    if kv_cache.key.dtype == torch.float8_e4m3fn:
        kv_cache_dtype = "fp8_e4m3"
    if ATTENTION == "flashdecoding-ipex":
        window_size_right = -1 if window_size_left == -1 else 0
        ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
            out,
            query.contiguous() if query.device.type == "xpu" else query,
            kv_cache.key,
            kv_cache.value,
            seqlen.cu_seqlen_q,
            seqlen.cu_seqlen_k,
            seqlen.max_q,
            seqlen.max_k,
            softmax_scale,
            True,
            block_tables,
            None,
            window_size_left=window_size_left,
            window_size_right=window_size_right,
            kv_cache_dtype=kv_cache_dtype,
            k_scale=kv_scales.key_scale_cpu,
            v_scale=kv_scales.value_scale_cpu,
            softcap=softcap,
        )
    else:
        # Total sequence length per request: new input tokens plus tokens
        # already stored in the KV cache.
        input_lengths = seqlen.input_lengths + seqlen.cache_lengths
        if softcap is not None:
            raise NotImplementedError(
                "softcap is not available in IPEX paged attention"
            )
        ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
            out,
            query,
            kv_cache.key,
            kv_cache.value,
            kv_head_mapping,
            softmax_scale,
            block_tables,
            input_lengths,
            BLOCK_SIZE,
            max_s,
            None,
            k_scale=kv_scales.key_scale_cpu,
            v_scale=kv_scales.value_scale_cpu,
        )

    return out


__all__ = [
    "SUPPORTS_WINDOWING",
    "attention",
    "paged_attention",
]
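

# Illustrative sketch (not part of the upstream module): both `attention` and
# `paged_attention` repeat the same fp8 dtype-to-string mapping for the KV
# cache. The helper below mirrors that selection logic in one place; the name
# `_kv_cache_dtype_str` is hypothetical and shown only for illustration.
def _kv_cache_dtype_str(key_cache: torch.Tensor) -> str:
    """Return the kv_cache_dtype string used by the IPEX kernels above."""
    if key_cache.dtype == torch.float8_e5m2:
        return "fp8_e5m2"
    if key_cache.dtype == torch.float8_e4m3fn:
        return "fp8_e4m3"
    # Any other dtype (e.g. float16 / bfloat16) falls back to automatic handling.
    return "auto"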