mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
flash decoding
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
parent
780531ec77
commit
d7c991b0d1
@ -1,9 +1,12 @@
|
||||
import intel_extension_for_pytorch as ipex
|
||||
import torch
|
||||
from text_generation_server.layers.attention.kv_cache import KVCache, KVScales
|
||||
from text_generation_server.models.flash_causal_lm import BLOCK_SIZE
|
||||
from text_generation_server.layers.attention import Seqlen
|
||||
from typing import Optional
|
||||
from text_generation_server.models.globals import (
|
||||
ATTENTION,
|
||||
BLOCK_SIZE,
|
||||
)
|
||||
|
||||
SUPPORTS_WINDOWING = False
|
||||
|
||||
@ -28,22 +31,38 @@ def attention(
|
||||
out = torch.empty_like(query)
|
||||
|
||||
# We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
|
||||
ipex.llm.functional.varlen_attention(
|
||||
query.contiguous() if query.device.type == "xpu" else query,
|
||||
key.contiguous() if key.device.type == "xpu" else key,
|
||||
value.contiguous() if value.device.type == "xpu" else value,
|
||||
out,
|
||||
seqlen.cu_seqlen_q,
|
||||
seqlen.cu_seqlen_q,
|
||||
seqlen.max_q,
|
||||
seqlen.max_q,
|
||||
0.0,
|
||||
softmax_scale,
|
||||
False,
|
||||
causal,
|
||||
False,
|
||||
None,
|
||||
)
|
||||
if ATTENTION == "flashdecoding":
|
||||
ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
|
||||
out,
|
||||
query,
|
||||
kv_cache.key,
|
||||
kv_cache.value,
|
||||
seqlen.cu_seqlen_q,
|
||||
seqlen.cu_seqlen_k,
|
||||
seqlen.max_q,
|
||||
seqlen.max_k,
|
||||
softmax_scale,
|
||||
causal,
|
||||
block_tables,
|
||||
None,
|
||||
)
|
||||
else:
|
||||
ipex.llm.functional.varlen_attention(
|
||||
query.contiguous() if query.device.type == "xpu" else query,
|
||||
key.contiguous() if key.device.type == "xpu" else key,
|
||||
value.contiguous() if value.device.type == "xpu" else value,
|
||||
out,
|
||||
seqlen.cu_seqlen_q,
|
||||
seqlen.cu_seqlen_q,
|
||||
seqlen.max_q,
|
||||
seqlen.max_q,
|
||||
0.0,
|
||||
softmax_scale,
|
||||
False,
|
||||
causal,
|
||||
False,
|
||||
None,
|
||||
)
|
||||
|
||||
return out
|
||||
|
||||
@ -64,20 +83,37 @@ def paged_attention(
|
||||
raise NotImplementedError("softcap is not available in IPEX")
|
||||
|
||||
out = torch.empty_like(query)
|
||||
input_lengths = seqlen.input_lengths + seqlen.cache_lengths
|
||||
ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
|
||||
out,
|
||||
query,
|
||||
kv_cache.key,
|
||||
kv_cache.value,
|
||||
kv_head_mapping,
|
||||
softmax_scale,
|
||||
block_tables,
|
||||
input_lengths,
|
||||
BLOCK_SIZE,
|
||||
max_s,
|
||||
None,
|
||||
)
|
||||
|
||||
if ATTENTION == "flashdecoding":
|
||||
ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
|
||||
out,
|
||||
query,
|
||||
kv_cache.key,
|
||||
kv_cache.value,
|
||||
seqlen.cu_seqlen_q,
|
||||
seqlen.cu_seqlen_k,
|
||||
seqlen.max_q,
|
||||
seqlen.max_k,
|
||||
softmax_scale,
|
||||
True,
|
||||
block_tables,
|
||||
None,
|
||||
)
|
||||
else:
|
||||
input_lengths = seqlen.input_lengths + seqlen.cache_lengths
|
||||
ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
|
||||
out,
|
||||
query,
|
||||
kv_cache.key,
|
||||
kv_cache.value,
|
||||
kv_head_mapping,
|
||||
softmax_scale,
|
||||
block_tables,
|
||||
input_lengths,
|
||||
BLOCK_SIZE,
|
||||
max_s,
|
||||
None,
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
|
@ -66,7 +66,7 @@ class KVCache:
|
||||
else:
|
||||
x = BLOCK_SIZE // element_size
|
||||
|
||||
if ATTENTION in {"flashdecoding", "flashinfer"}:
|
||||
if ATTENTION in {"flashdecoding", "flashinfer"} and SYSTEM != "ipex":
|
||||
self.kv_cache = (
|
||||
torch.empty(
|
||||
(num_blocks, BLOCK_SIZE, num_heads, head_size),
|
||||
@ -174,7 +174,7 @@ class KVCache:
|
||||
scalar=True,
|
||||
)[0]
|
||||
|
||||
if ATTENTION in {"flashdecoding", "flashinfer"}:
|
||||
if ATTENTION in {"flashdecoding", "flashinfer"} and SYSTEM != "ipex":
|
||||
key = key.to(key_cache.dtype)
|
||||
value = value.to(value_cache.dtype)
|
||||
if key_cache.dtype in {torch.float8_e4m3fn, torch.float8_e5m2}:
|
||||
|
Loading…
Reference in New Issue
Block a user