mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
enable xpu flashdecoding
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
parent
d7c991b0d1
commit
d04c86c76c
@ -34,7 +34,7 @@ def attention(
|
||||
if ATTENTION == "flashdecoding":
|
||||
ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
|
||||
out,
|
||||
query,
|
||||
query.contiguous() if query.device.type == "xpu" else query,
|
||||
kv_cache.key,
|
||||
kv_cache.value,
|
||||
seqlen.cu_seqlen_q,
|
||||
@ -87,7 +87,7 @@ def paged_attention(
|
||||
if ATTENTION == "flashdecoding":
|
||||
ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
|
||||
out,
|
||||
query,
|
||||
query.contiguous() if query.device.type == "xpu" else query,
|
||||
kv_cache.key,
|
||||
kv_cache.value,
|
||||
seqlen.cu_seqlen_q,
|
||||
|
@ -66,7 +66,9 @@ class KVCache:
|
||||
else:
|
||||
x = BLOCK_SIZE // element_size
|
||||
|
||||
if ATTENTION in {"flashdecoding", "flashinfer"} and SYSTEM != "ipex":
|
||||
if ATTENTION in {"flashdecoding", "flashinfer"} and not (
|
||||
SYSTEM == "ipex" and device == torch.device("cpu")
|
||||
):
|
||||
self.kv_cache = (
|
||||
torch.empty(
|
||||
(num_blocks, BLOCK_SIZE, num_heads, head_size),
|
||||
@ -174,7 +176,9 @@ class KVCache:
|
||||
scalar=True,
|
||||
)[0]
|
||||
|
||||
if ATTENTION in {"flashdecoding", "flashinfer"} and SYSTEM != "ipex":
|
||||
if ATTENTION in {"flashdecoding", "flashinfer"} and not (
|
||||
SYSTEM == "ipex" and key.device == torch.device("cpu")
|
||||
):
|
||||
key = key.to(key_cache.dtype)
|
||||
value = value.to(value_cache.dtype)
|
||||
if key_cache.dtype in {torch.float8_e4m3fn, torch.float8_e5m2}:
|
||||
|
@ -4,6 +4,7 @@ from loguru import logger
|
||||
from typing import Dict, Optional
|
||||
|
||||
from text_generation_server.utils.log import log_master
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
|
||||
ATTENTION = os.environ["ATTENTION"]
|
||||
# default_prefix_caching = "1" if ATTENTION in {"flashinfer", "flashdecoding"} else "0"
|
||||
@ -27,9 +28,12 @@ TGI_WIGGLE_ROOM = float(os.getenv("TGI_WIGGLE_ROOM", "0.90"))
|
||||
assert TGI_WIGGLE_ROOM > 0
|
||||
assert TGI_WIGGLE_ROOM < 1
|
||||
|
||||
|
||||
# This is overridden by the cli
|
||||
BLOCK_SIZE: int
|
||||
if ATTENTION == "flashdecoding":
|
||||
if SYSTEM == "ipex":
|
||||
BLOCK_SIZE = 16
|
||||
elif ATTENTION == "flashdecoding":
|
||||
BLOCK_SIZE = 256
|
||||
elif ATTENTION == "flashinfer":
|
||||
BLOCK_SIZE = 1
|
||||
|
Loading…
Reference in New Issue
Block a user