flash decoding

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Wang, Yi A 2024-11-05 00:48:23 -08:00
parent 780531ec77
commit d7c991b0d1
2 changed files with 69 additions and 33 deletions

File 1 of 2: the IPEX attention backend (attention / paged_attention)

@@ -1,9 +1,12 @@
import intel_extension_for_pytorch as ipex
import torch
from text_generation_server.layers.attention.kv_cache import KVCache, KVScales
-from text_generation_server.models.flash_causal_lm import BLOCK_SIZE
from text_generation_server.layers.attention import Seqlen
from typing import Optional
+from text_generation_server.models.globals import (
+    ATTENTION,
+    BLOCK_SIZE,
+)
SUPPORTS_WINDOWING = False
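The globals import brings in ATTENTION (the backend selector checked below) together with BLOCK_SIZE, the number of KV slots per cache page. The block_tables tensors handed to the kernels in the next hunks map each sequence's logical pages onto physical blocks of BLOCK_SIZE slots. A generic sketch of that mapping, for illustration only (not code from this commit; seq_lengths and the allocation loop are made up):

import torch

BLOCK_SIZE = 64
seq_lengths = torch.tensor([100, 17, 300])
# ceil division: pages needed per sequence -> [2, 1, 5]
blocks_needed = (seq_lengths + BLOCK_SIZE - 1) // BLOCK_SIZE

# (num_seqs, max_blocks_per_seq) table of physical block ids, padded with 0
max_blocks = int(blocks_needed.max())
block_tables = torch.zeros(len(seq_lengths), max_blocks, dtype=torch.int32)
next_free = 0
for i, n in enumerate(blocks_needed.tolist()):
    block_tables[i, :n] = torch.arange(next_free, next_free + n, dtype=torch.int32)
    next_free += n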
@@ -28,6 +31,22 @@ def attention(
    out = torch.empty_like(query)
    # We do not need to check window_size_left (not supported) here, as it is already checked ahead of time at model load.
    if ATTENTION == "flashdecoding":
        ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
            out,
            query,
            kv_cache.key,
            kv_cache.value,
            seqlen.cu_seqlen_q,
            seqlen.cu_seqlen_k,
            seqlen.max_q,
            seqlen.max_k,
            softmax_scale,
            causal,
            block_tables,
            None,
        )
    else:
        ipex.llm.functional.varlen_attention(
            query.contiguous() if query.device.type == "xpu" else query,
            key.contiguous() if key.device.type == "xpu" else key,
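In the flashdecoding branch above, flash_attn_varlen_func operates on ragged batches described by cumulative sequence-length offsets (seqlen.cu_seqlen_q, seqlen.cu_seqlen_k) plus the batch maxima (seqlen.max_q, seqlen.max_k). The sketch below shows how such offsets are typically built from per-request lengths; the helper is hypothetical and only illustrates the convention, since TGI's Seqlen object supplies these values precomputed:

import torch
import torch.nn.functional as F

def cu_seqlens(lengths: torch.Tensor) -> torch.Tensor:
    # [3, 5, 2] -> [0, 3, 8, 10]: start offset of each request in the packed token dimension
    return F.pad(torch.cumsum(lengths, dim=0, dtype=torch.int32), (1, 0))

q_lengths = torch.tensor([3, 5, 2], dtype=torch.int32)  # new tokens per request
k_lengths = torch.tensor([7, 9, 4], dtype=torch.int32)  # cached + new tokens per request
cu_seqlen_q, cu_seqlen_k = cu_seqlens(q_lengths), cu_seqlens(k_lengths)
max_q, max_k = int(q_lengths.max()), int(k_lengths.max())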
@@ -64,6 +83,23 @@ def paged_attention(
        raise NotImplementedError("softcap is not available in IPEX")
    out = torch.empty_like(query)
    if ATTENTION == "flashdecoding":
        ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
            out,
            query,
            kv_cache.key,
            kv_cache.value,
            seqlen.cu_seqlen_q,
            seqlen.cu_seqlen_k,
            seqlen.max_q,
            seqlen.max_k,
            softmax_scale,
            True,
            block_tables,
            None,
        )
    else:
        input_lengths = seqlen.input_lengths + seqlen.cache_lengths
        ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
            out,

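In paged_attention above, the flashdecoding path reuses the same varlen kernel for decode (with the causal flag hard-coded to True), while the else branch keeps IPEX's single_query_cached_kv_attention over input_lengths = cached + new tokens. The idea behind flash decoding is that the KV sequence can be split into chunks that are reduced independently and then merged with an online softmax. The sketch below shows that reduction for one query vector in plain PyTorch; it is a numerical illustration of the technique, not the IPEX kernel:

import torch

def chunked_decode_attention(q, k, v, chunk=32, scale=None):
    # q: (head_dim,), k/v: (seq_len, head_dim); processes k/v chunk by chunk and
    # merges partial results with a running max and softmax denominator.
    scale = q.shape[-1] ** -0.5 if scale is None else scale
    m = torch.tensor(float("-inf"))  # running max of attention scores
    l = torch.tensor(0.0)            # running softmax denominator
    acc = torch.zeros_like(q)        # running weighted sum of values
    for start in range(0, k.shape[0], chunk):
        ks, vs = k[start:start + chunk], v[start:start + chunk]
        s = (ks @ q) * scale              # scores for this chunk
        m_new = torch.maximum(m, s.max())
        alpha = torch.exp(m - m_new)      # rescales previously accumulated partials
        p = torch.exp(s - m_new)
        acc = acc * alpha + p @ vs
        l = l * alpha + p.sum()
        m = m_new
    return acc / l

q = torch.randn(64)
k, v = torch.randn(128, 64), torch.randn(128, 64)
reference = torch.softmax((k @ q) * 64 ** -0.5, dim=0) @ v
assert torch.allclose(chunked_decode_attention(q, k, v), reference, atol=1e-4)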
File 2 of 2: the KVCache implementation

@@ -66,7 +66,7 @@ class KVCache:
        else:
            x = BLOCK_SIZE // element_size
-        if ATTENTION in {"flashdecoding", "flashinfer"}:
+        if ATTENTION in {"flashdecoding", "flashinfer"} and SYSTEM != "ipex":
            self.kv_cache = (
                torch.empty(
                    (num_blocks, BLOCK_SIZE, num_heads, head_size),
@@ -174,7 +174,7 @@ class KVCache:
                    scalar=True,
                )[0]
-        if ATTENTION in {"flashdecoding", "flashinfer"}:
+        if ATTENTION in {"flashdecoding", "flashinfer"} and SYSTEM != "ipex":
            key = key.to(key_cache.dtype)
            value = value.to(value_cache.dtype)
            if key_cache.dtype in {torch.float8_e4m3fn, torch.float8_e5m2}:
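Both hunks in the KVCache file add the SYSTEM != "ipex" guard, so when flashdecoding is enabled the IPEX backend keeps allocating and storing the cache on the non-flashdecoding path rather than the flat (num_blocks, BLOCK_SIZE, num_heads, head_size) layout. A shape-only sketch of the two layouts in plain PyTorch; the fallback shapes are an assumption based on the usual vLLM-style paged cache and are not shown in this hunk:

import torch

num_blocks, BLOCK_SIZE, num_heads, head_size = 8, 64, 4, 128
dtype = torch.float16
element_size = torch.tensor([], dtype=dtype).element_size()  # 2 bytes for fp16
x = BLOCK_SIZE // element_size                                # as computed in the diff

# flashdecoding / flashinfer layout (shape taken from the diff above):
# one (BLOCK_SIZE, num_heads, head_size) page per block, for key and for value
flat_key = torch.empty(num_blocks, BLOCK_SIZE, num_heads, head_size, dtype=dtype)
flat_value = torch.empty_like(flat_key)

# assumed fallback layout (vLLM-style paged cache) that the ipex path keeps using
paged_key = torch.empty(num_blocks, num_heads, head_size // x, BLOCK_SIZE, x, dtype=dtype)
paged_value = torch.empty(num_blocks, num_heads, head_size, BLOCK_SIZE, dtype=dtype)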