text-generation-inference/server/text_generation_server/layers/attention/ipex.py
Nicolas Patry 9e2fdf57c0
Removing IPEX_AVAIL. (#2115)
* Removing IPEX_AVAIL.

Chose to unify CPU and XPU under `ipex`. Most of the code is identical,
except for a very few spots.

The biggest number of spots is the kv-cache layout and the flash_xxx.py
files.
Since those files should be removed soon and factored away, we should
not need them.

* Forgot a few places.

* Unrelated change.

* Fixing HF_TOKEN.

* HF_TOKEN
2024-06-25 13:20:57 +02:00

73 lines
1.5 KiB
Python

import intel_extension_for_pytorch as ipex
import torch
from text_generation_server.models.flash_causal_lm import BLOCK_SIZE
# Sliding-window attention is not handled by this backend: `attention`
# accepts `window_size_left` but never forwards it to the IPEX kernel.
SUPPORTS_WINDOWING = False
def attention(
    q,
    k,
    v,
    out,
    cu_seqlens,
    max_s,
    softmax_scale,
    window_size_left=-1,
):
    """Compute attention through IPEX's variable-length attention kernel.

    This is a thin wrapper around ``ipex.llm.functional.varlen_attention``.
    ``cu_seqlens`` and ``max_s`` are each passed twice, in consecutive
    positional slots (presumably the query-side and key-side values --
    TODO confirm against the installed IPEX version).

    ``window_size_left`` is accepted only to keep the signature compatible
    with callers; it is never read here. Windowing is not supported by this
    backend, and that is expected to have been checked ahead of time, at
    model load.

    Returns whatever the IPEX kernel returns.
    """
    # NOTE(review): these local names mirror the parameter names in the IPEX
    # `varlen_attention` documentation; the kernel receives them positionally,
    # so verify the order if the IPEX version changes.
    dropout_p = 0.0
    zero_tensors = False
    is_causal = True
    return_softmax = False
    generator = None

    kernel_args = (
        q,
        k,
        v,
        out,
        cu_seqlens,
        cu_seqlens,
        max_s,
        max_s,
        dropout_p,
        softmax_scale,
        zero_tensors,
        is_causal,
        return_softmax,
        generator,
    )
    return ipex.llm.functional.varlen_attention(*kernel_args)
def reshape_and_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slots: torch.Tensor,
):
    """Hand ``key``/``value`` to IPEX's paged-attention cache writer.

    Delegates to ``ipex.llm.modules.PagedAttention.reshape_and_cache`` with
    the arguments in the same order they are received. The IPEX call's
    result is discarded, so this function always returns ``None``.

    NOTE(review): presumably the IPEX routine writes into ``key_cache`` and
    ``value_cache`` in place at the positions given by ``slots`` -- confirm
    against the IPEX documentation.
    """
    write_kv = ipex.llm.modules.PagedAttention.reshape_and_cache
    write_kv(key, value, key_cache, value_cache, slots)
def paged_attention(
    out: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    kv_head_mapping: torch.Tensor,
    softmax_scale: float,
    block_tables: torch.Tensor,
    input_lengths: torch.Tensor,
    max_s: int,
):
    """Run IPEX's single-query attention over the cached keys and values.

    Delegates to
    ``ipex.llm.modules.PagedAttention.single_query_cached_kv_attention``,
    forwarding every argument in the order received and inserting the
    module-level ``BLOCK_SIZE`` between ``input_lengths`` and ``max_s``.

    Returns whatever the IPEX kernel returns.
    """
    # NOTE(review): the final positional slot is presumably the ALiBi slopes
    # argument of the IPEX kernel; it is always left unset here -- confirm.
    alibi_slopes = None

    decode_kernel = ipex.llm.modules.PagedAttention.single_query_cached_kv_attention
    return decode_kernel(
        out,
        query,
        key_cache,
        value_cache,
        kv_head_mapping,
        softmax_scale,
        block_tables,
        input_lengths,
        BLOCK_SIZE,
        max_s,
        alibi_slopes,
    )