Add flash decoding support to the IPEX attention backend

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Wang, Yi A 2024-11-05 00:48:23 -08:00
parent 780531ec77
commit d7c991b0d1
2 changed files with 69 additions and 33 deletions

text_generation_server/layers/attention/ipex.py

@@ -1,9 +1,12 @@
 import intel_extension_for_pytorch as ipex
 import torch
 from text_generation_server.layers.attention.kv_cache import KVCache, KVScales
-from text_generation_server.models.flash_causal_lm import BLOCK_SIZE
 from text_generation_server.layers.attention import Seqlen
 from typing import Optional
+from text_generation_server.models.globals import (
+    ATTENTION,
+    BLOCK_SIZE,
+)
 
 SUPPORTS_WINDOWING = False
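The new import pulls ATTENTION and BLOCK_SIZE from models/globals.py instead of flash_causal_lm. As context only (not part of the diff), here is a minimal sketch of how such globals are commonly resolved, assuming an ATTENTION environment variable and per-backend block sizes; the selection logic and default values are assumptions and may differ from the real globals.py:

# Hypothetical sketch of models/globals.py; names follow the imports above.
import os

# Attention backend chosen at startup, e.g. "paged", "flashdecoding" or "flashinfer".
ATTENTION = os.environ.get("ATTENTION", "paged")

# Paged-KV-cache block size depends on the backend (assumed values).
if ATTENTION == "flashdecoding":
    BLOCK_SIZE = 256
elif ATTENTION == "flashinfer":
    BLOCK_SIZE = 1
else:
    BLOCK_SIZE = 16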
@@ -28,22 +31,38 @@ def attention(
     out = torch.empty_like(query)
     # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
-    ipex.llm.functional.varlen_attention(
-        query.contiguous() if query.device.type == "xpu" else query,
-        key.contiguous() if key.device.type == "xpu" else key,
-        value.contiguous() if value.device.type == "xpu" else value,
-        out,
-        seqlen.cu_seqlen_q,
-        seqlen.cu_seqlen_q,
-        seqlen.max_q,
-        seqlen.max_q,
-        0.0,
-        softmax_scale,
-        False,
-        causal,
-        False,
-        None,
-    )
+    if ATTENTION == "flashdecoding":
+        ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
+            out,
+            query,
+            kv_cache.key,
+            kv_cache.value,
+            seqlen.cu_seqlen_q,
+            seqlen.cu_seqlen_k,
+            seqlen.max_q,
+            seqlen.max_k,
+            softmax_scale,
+            causal,
+            block_tables,
+            None,
+        )
+    else:
+        ipex.llm.functional.varlen_attention(
+            query.contiguous() if query.device.type == "xpu" else query,
+            key.contiguous() if key.device.type == "xpu" else key,
+            value.contiguous() if value.device.type == "xpu" else value,
+            out,
+            seqlen.cu_seqlen_q,
+            seqlen.cu_seqlen_q,
+            seqlen.max_q,
+            seqlen.max_q,
+            0.0,
+            softmax_scale,
+            False,
+            causal,
+            False,
+            None,
+        )
     return out
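As a reference for what the prefill path computes (illustration only, not part of the diff): with flashdecoding enabled, flash_attn_varlen_func reads keys and values straight from the paged kv_cache via block_tables, while the fallback runs varlen_attention over the contiguous key/value tensors. Below is a naive PyTorch sketch of variable-length causal attention, assuming query/key/value are packed as [total_tokens, num_heads, head_size] and cu_seqlens are the cumulative sequence lengths carried by Seqlen:

import torch


def varlen_causal_attention_reference(query, key, value, cu_seqlens, softmax_scale):
    # Naive O(n^2) reference, one sequence at a time; the IPEX kernels fuse this.
    out = torch.empty_like(query)
    for i in range(cu_seqlens.numel() - 1):
        start, end = int(cu_seqlens[i]), int(cu_seqlens[i + 1])
        q, k, v = query[start:end], key[start:end], value[start:end]
        # scores[h, q_pos, k_pos] = sum_d q[q_pos, h, d] * k[k_pos, h, d]
        scores = torch.einsum("qhd,khd->hqk", q, k) * softmax_scale
        seq = end - start
        causal_mask = torch.triu(
            torch.ones(seq, seq, dtype=torch.bool, device=q.device), diagonal=1
        )
        scores = scores.masked_fill(causal_mask, float("-inf"))
        probs = torch.softmax(scores, dim=-1)
        out[start:end] = torch.einsum("hqk,khd->qhd", probs, v)
    return out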
@@ -64,20 +83,37 @@ def paged_attention(
         raise NotImplementedError("softcap is not available in IPEX")
     out = torch.empty_like(query)
-    input_lengths = seqlen.input_lengths + seqlen.cache_lengths
-    ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
-        out,
-        query,
-        kv_cache.key,
-        kv_cache.value,
-        kv_head_mapping,
-        softmax_scale,
-        block_tables,
-        input_lengths,
-        BLOCK_SIZE,
-        max_s,
-        None,
-    )
+    if ATTENTION == "flashdecoding":
+        ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
+            out,
+            query,
+            kv_cache.key,
+            kv_cache.value,
+            seqlen.cu_seqlen_q,
+            seqlen.cu_seqlen_k,
+            seqlen.max_q,
+            seqlen.max_k,
+            softmax_scale,
+            True,
+            block_tables,
+            None,
+        )
+    else:
+        input_lengths = seqlen.input_lengths + seqlen.cache_lengths
+        ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
+            out,
+            query,
+            kv_cache.key,
+            kv_cache.value,
+            kv_head_mapping,
+            softmax_scale,
+            block_tables,
+            input_lengths,
+            BLOCK_SIZE,
+            max_s,
+            None,
+        )
     return out
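For decode (illustration only, not part of the diff), both branches attend the new token of each sequence against that sequence's blocks of the paged KV cache, located through block_tables. A naive PyTorch reference of that computation, assuming the (num_blocks, BLOCK_SIZE, num_heads, head_size) cache layout from kv_cache.py below and num_heads == num_kv_heads:

import torch

BLOCK_SIZE = 16  # assumed block size for the sketch


def paged_decode_attention_reference(
    query,          # [num_seqs, num_heads, head_size], one decode token per sequence
    key_cache,      # [num_blocks, BLOCK_SIZE, num_heads, head_size]
    value_cache,    # [num_blocks, BLOCK_SIZE, num_heads, head_size]
    block_tables,   # [num_seqs, max_blocks_per_seq] integer block indices
    input_lengths,  # [num_seqs] number of cached tokens per sequence
    softmax_scale,
):
    num_seqs, num_heads, head_size = query.shape
    out = torch.empty_like(query)
    for i in range(num_seqs):
        length = int(input_lengths[i])
        num_blocks = (length + BLOCK_SIZE - 1) // BLOCK_SIZE
        blocks = block_tables[i, :num_blocks].long()
        # Gather this sequence's keys/values and trim the trailing partial block.
        k = key_cache[blocks].reshape(-1, num_heads, head_size)[:length]
        v = value_cache[blocks].reshape(-1, num_heads, head_size)[:length]
        # Single-query attention: no causal mask needed beyond the cache length.
        scores = torch.einsum("hd,shd->hs", query[i], k) * softmax_scale
        probs = torch.softmax(scores, dim=-1)
        out[i] = torch.einsum("hs,shd->hd", probs, v)
    return out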

text_generation_server/layers/attention/kv_cache.py

@@ -66,7 +66,7 @@ class KVCache:
         else:
             x = BLOCK_SIZE // element_size
-        if ATTENTION in {"flashdecoding", "flashinfer"}:
+        if ATTENTION in {"flashdecoding", "flashinfer"} and SYSTEM != "ipex":
             self.kv_cache = (
                 torch.empty(
                     (num_blocks, BLOCK_SIZE, num_heads, head_size),
@@ -174,7 +174,7 @@ class KVCache:
                     scalar=True,
                 )[0]
-        if ATTENTION in {"flashdecoding", "flashinfer"}:
+        if ATTENTION in {"flashdecoding", "flashinfer"} and SYSTEM != "ipex":
            key = key.to(key_cache.dtype)
            value = value.to(value_cache.dtype)
            if key_cache.dtype in {torch.float8_e4m3fn, torch.float8_e5m2}:
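With SYSTEM != "ipex" added to both conditions, the IPEX backend no longer takes the flashdecoding/flashinfer branch here and falls through to the other cache path (not shown in this hunk). For reference (illustration only, not part of the diff), storing new tokens into the flashdecoding-style (num_blocks, BLOCK_SIZE, num_heads, head_size) layout allocated above reduces to a block-table lookup plus an indexed write; a generic sketch with a hypothetical helper:

import torch

BLOCK_SIZE = 16  # assumed


def store_kv(key_cache, value_cache, key, value, block_table, positions):
    # key/value: [num_tokens, num_heads, head_size] for one sequence
    # block_table: [max_blocks_per_seq] block indices for that sequence
    # positions: [num_tokens] absolute token positions within the sequence
    blocks = block_table[positions // BLOCK_SIZE].long()  # block each token lands in
    slots = (positions % BLOCK_SIZE).long()               # offset inside the block
    key_cache[blocks, slots] = key
    value_cache[blocks, slots] = value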