mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
flash decoding
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
parent 780531ec77
commit d7c991b0d1
@@ -1,9 +1,12 @@
 import intel_extension_for_pytorch as ipex
 import torch
 from text_generation_server.layers.attention.kv_cache import KVCache, KVScales
-from text_generation_server.models.flash_causal_lm import BLOCK_SIZE
 from text_generation_server.layers.attention import Seqlen
 from typing import Optional
+from text_generation_server.models.globals import (
+    ATTENTION,
+    BLOCK_SIZE,
+)
 
 SUPPORTS_WINDOWING = False
 
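The import change above routes both ATTENTION and BLOCK_SIZE through text_generation_server.models.globals instead of flash_causal_lm, so the IPEX backend can branch on the selected attention mode. As a rough, illustrative sketch (not part of this diff), the backend flag can be thought of as a process-wide setting resolved once at startup; the environment-variable name and the default below are assumptions for illustration, not taken from this commit.

# Hypothetical sketch of a globals-style backend flag; env var name and
# default value are assumptions, not the actual globals.py implementation.
import os

ATTENTION = os.environ.get("ATTENTION", "paged")
if ATTENTION not in {"paged", "flashdecoding", "flashinfer"}:
    raise ValueError(f"Unsupported ATTENTION backend: {ATTENTION}")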
@@ -28,22 +31,38 @@ def attention(
     out = torch.empty_like(query)
 
     # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
-    ipex.llm.functional.varlen_attention(
-        query.contiguous() if query.device.type == "xpu" else query,
-        key.contiguous() if key.device.type == "xpu" else key,
-        value.contiguous() if value.device.type == "xpu" else value,
-        out,
-        seqlen.cu_seqlen_q,
-        seqlen.cu_seqlen_q,
-        seqlen.max_q,
-        seqlen.max_q,
-        0.0,
-        softmax_scale,
-        False,
-        causal,
-        False,
-        None,
-    )
+    if ATTENTION == "flashdecoding":
+        ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
+            out,
+            query,
+            kv_cache.key,
+            kv_cache.value,
+            seqlen.cu_seqlen_q,
+            seqlen.cu_seqlen_k,
+            seqlen.max_q,
+            seqlen.max_k,
+            softmax_scale,
+            causal,
+            block_tables,
+            None,
+        )
+    else:
+        ipex.llm.functional.varlen_attention(
+            query.contiguous() if query.device.type == "xpu" else query,
+            key.contiguous() if key.device.type == "xpu" else key,
+            value.contiguous() if value.device.type == "xpu" else value,
+            out,
+            seqlen.cu_seqlen_q,
+            seqlen.cu_seqlen_q,
+            seqlen.max_q,
+            seqlen.max_q,
+            0.0,
+            softmax_scale,
+            False,
+            causal,
+            False,
+            None,
+        )
 
     return out
 
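For readers unfamiliar with the varlen calling convention used in the prefill path above, here is a minimal, CPU-runnable reference of what variable-length causal attention computes when several sequences are packed along the token dimension and delimited by cu_seqlen_q. It is a plain-PyTorch sketch for illustration, not the IPEX kernel, and it assumes query and key lengths are equal (no prefix cache).

import torch

def varlen_attention_reference(query, key, value, cu_seqlen_q, softmax_scale, causal=True):
    # query/key/value: [total_tokens, num_heads, head_dim], packed per sequence
    out = torch.empty_like(query)
    for i in range(cu_seqlen_q.numel() - 1):
        s, e = int(cu_seqlen_q[i]), int(cu_seqlen_q[i + 1])
        q = query[s:e].transpose(0, 1)  # [num_heads, seq_len, head_dim]
        k = key[s:e].transpose(0, 1)
        v = value[s:e].transpose(0, 1)
        o = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, is_causal=causal, scale=softmax_scale
        )
        out[s:e] = o.transpose(0, 1)
    return out

# Two packed sequences of lengths 3 and 5:
q = torch.randn(8, 4, 64)
out = varlen_attention_reference(q, q, q, torch.tensor([0, 3, 8]), softmax_scale=64**-0.5)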
@@ -64,20 +83,37 @@ def paged_attention(
         raise NotImplementedError("softcap is not available in IPEX")
 
     out = torch.empty_like(query)
-    input_lengths = seqlen.input_lengths + seqlen.cache_lengths
-    ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
-        out,
-        query,
-        kv_cache.key,
-        kv_cache.value,
-        kv_head_mapping,
-        softmax_scale,
-        block_tables,
-        input_lengths,
-        BLOCK_SIZE,
-        max_s,
-        None,
-    )
+    if ATTENTION == "flashdecoding":
+        ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
+            out,
+            query,
+            kv_cache.key,
+            kv_cache.value,
+            seqlen.cu_seqlen_q,
+            seqlen.cu_seqlen_k,
+            seqlen.max_q,
+            seqlen.max_k,
+            softmax_scale,
+            True,
+            block_tables,
+            None,
+        )
+    else:
+        input_lengths = seqlen.input_lengths + seqlen.cache_lengths
+        ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
+            out,
+            query,
+            kv_cache.key,
+            kv_cache.value,
+            kv_head_mapping,
+            softmax_scale,
+            block_tables,
+            input_lengths,
+            BLOCK_SIZE,
+            max_s,
+            None,
+        )
     return out
 
 
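The flashdecoding branch above attends one new query token per sequence against K/V that already live in the paged cache and are addressed through block_tables. The following CPU-runnable sketch shows that lookup conceptually, assuming the (num_blocks, block_size, num_heads, head_dim) cache layout that appears elsewhere in this diff; the real IPEX kernel fuses the gather and the attention.

import torch

def paged_decode_reference(query, key_cache, value_cache, block_tables, input_lengths, softmax_scale):
    # query: [num_seqs, num_heads, head_dim] -- one decode token per sequence
    # key_cache/value_cache: [num_blocks, block_size, num_heads, head_dim]
    block_size = key_cache.shape[1]
    out = torch.empty_like(query)
    for i in range(query.shape[0]):
        length = int(input_lengths[i])
        n_blocks = (length + block_size - 1) // block_size
        blocks = block_tables[i, :n_blocks].long()
        # Gather this sequence's cached K/V and trim padding in its last block.
        k = key_cache[blocks].flatten(0, 1)[:length].transpose(0, 1)   # [num_heads, length, head_dim]
        v = value_cache[blocks].flatten(0, 1)[:length].transpose(0, 1)
        q = query[i].unsqueeze(1)                                       # [num_heads, 1, head_dim]
        o = torch.nn.functional.scaled_dot_product_attention(q, k, v, scale=softmax_scale)
        out[i] = o.squeeze(1)
    return out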
@@ -66,7 +66,7 @@ class KVCache:
         else:
             x = BLOCK_SIZE // element_size
 
-        if ATTENTION in {"flashdecoding", "flashinfer"}:
+        if ATTENTION in {"flashdecoding", "flashinfer"} and SYSTEM != "ipex":
             self.kv_cache = (
                 torch.empty(
                     (num_blocks, BLOCK_SIZE, num_heads, head_size),
@@ -174,7 +174,7 @@ class KVCache:
                 scalar=True,
             )[0]
 
-        if ATTENTION in {"flashdecoding", "flashinfer"}:
+        if ATTENTION in {"flashdecoding", "flashinfer"} and SYSTEM != "ipex":
             key = key.to(key_cache.dtype)
             value = value.to(value_cache.dtype)
             if key_cache.dtype in {torch.float8_e4m3fn, torch.float8_e5m2}:
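The two kv_cache.py hunks keep the IPEX backend on the paged cache layout even when ATTENTION is "flashdecoding", since ipex.llm.modules.PagedAttention consumes block-paged tensors. The sketch below contrasts the two allocations: the flashdecoding/flashinfer shape comes from the hunk above, while the split "x" layout for the paged path is an assumption in this sketch (only x = BLOCK_SIZE // element_size is visible in the diff).

import torch

def allocate_kv_cache(attention, system, num_blocks, block_size, num_heads, head_size, dtype, device="cpu"):
    element_size = torch.tensor([], dtype=dtype).element_size()
    if attention in {"flashdecoding", "flashinfer"} and system != "ipex":
        # Flat, token-major layout used by flash-decoding style kernels.
        shape = (num_blocks, block_size, num_heads, head_size)
        return torch.empty(shape, dtype=dtype, device=device), torch.empty(shape, dtype=dtype, device=device)
    # Paged layout; the split along head_size is assumed here, not shown in the diff.
    x = block_size // element_size
    key = torch.empty((num_blocks, num_heads, head_size // x, block_size, x), dtype=dtype, device=device)
    value = torch.empty((num_blocks, num_heads, head_size, block_size), dtype=dtype, device=device)
    return key, value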