text-generation-inference/server/text_generation_server/layers/attention/ipex.py
import intel_extension_for_pytorch as ipex
import torch
from text_generation_server.layers.attention.kv_cache import KVCache, KVScales
from text_generation_server.layers.attention import Seqlen
from typing import Optional
from text_generation_server.models.globals import (
    ATTENTION,
    BLOCK_SIZE,
)

SUPPORTS_WINDOWING = False


def attention(
    *,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: KVCache,
    kv_scales: KVScales,
    seqlen: Seqlen,
    block_tables: torch.Tensor,
    softmax_scale: float,
    window_size_left: int = -1,
    causal: bool = True,
    softcap: Optional[float] = None,
):
    if softcap is not None:
        raise NotImplementedError("softcap is not available in IPEX")

    out = torch.empty_like(query)

    # We do not need to check window_size_left here (windowing is not
    # supported); it is already checked ahead of time at model load.
    if ATTENTION == "flashdecoding-ipex":
        # Flash-decoding kernel: attends over the paged KV cache directly.
        ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
            out,
            query.contiguous() if query.device.type == "xpu" else query,
            kv_cache.key,
            kv_cache.value,
            seqlen.cu_seqlen_q,
            seqlen.cu_seqlen_k,
            seqlen.max_q,
            seqlen.max_k,
            softmax_scale,
            causal,
            block_tables,
            None,
        )
    else:
        # Variable-length attention over the freshly computed key/value tensors.
        ipex.llm.functional.varlen_attention(
            query.contiguous() if query.device.type == "xpu" else query,
            key.contiguous() if key.device.type == "xpu" else key,
            value.contiguous() if value.device.type == "xpu" else value,
            out,
            seqlen.cu_seqlen_q,
            seqlen.cu_seqlen_q,
            seqlen.max_q,
            seqlen.max_q,
            0.0,
            softmax_scale,
            False,
            causal,
            False,
            None,
        )

    return out


def paged_attention(
    query: torch.Tensor,
    kv_cache: KVCache,
    kv_head_mapping: torch.Tensor,
    softmax_scale: float,
    block_tables: torch.Tensor,
    seqlen: Seqlen,
    max_s: int,
    *,
    kv_scales: KVScales,
    softcap: Optional[float] = None,
    window_size_left: Optional[int] = -1,
):
    if softcap is not None:
        raise NotImplementedError("softcap is not available in IPEX")

    out = torch.empty_like(query)

    if ATTENTION == "flashdecoding-ipex":
        ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
            out,
            query.contiguous() if query.device.type == "xpu" else query,
            kv_cache.key,
            kv_cache.value,
            seqlen.cu_seqlen_q,
            seqlen.cu_seqlen_k,
            seqlen.max_q,
            seqlen.max_k,
            softmax_scale,
            True,
            block_tables,
            None,
        )
    else:
        # Total context length per sequence: new tokens plus already-cached tokens.
        input_lengths = seqlen.input_lengths + seqlen.cache_lengths
        # Single-query decode over the paged KV cache.
        ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
            out,
            query,
            kv_cache.key,
            kv_cache.value,
            kv_head_mapping,
            softmax_scale,
            block_tables,
            input_lengths,
            BLOCK_SIZE,
            max_s,
            None,
        )

    return out


__all__ = [
    "SUPPORTS_WINDOWING",
    "attention",
    "paged_attention",
]
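

# --- Usage sketch (illustrative, not part of the upstream module) ---
# A minimal example of how a caller might dispatch between the two kernels
# above: `attention` for prefill over freshly computed key/value tensors and
# `paged_attention` for decode over the paged KV cache. The wrapper name
# `_example_dispatch` and the `prefill` flag are assumptions made for
# illustration; real TGI model code wires this up inside its attention layers.
def _example_dispatch(
    *,
    prefill: bool,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: KVCache,
    kv_scales: KVScales,
    kv_head_mapping: torch.Tensor,
    seqlen: Seqlen,
    block_tables: torch.Tensor,
    softmax_scale: float,
    max_s: int,
) -> torch.Tensor:
    if prefill:
        # Prefill: attend over all prompt tokens at once.
        return attention(
            query=query,
            key=key,
            value=value,
            kv_cache=kv_cache,
            kv_scales=kv_scales,
            seqlen=seqlen,
            block_tables=block_tables,
            softmax_scale=softmax_scale,
        )
    # Decode: one new query token per sequence attending over the cached KV.
    return paged_attention(
        query,
        kv_cache,
        kv_head_mapping,
        softmax_scale,
        block_tables,
        seqlen,
        max_s,
        kv_scales=kv_scales,
    )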