enable xpu flashdecoding

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Author: Wang, Yi A
Date:   2024-11-24 21:40:00 -08:00
parent d7c991b0d1
commit d04c86c76c
3 changed files with 13 additions and 5 deletions

server/text_generation_server/layers/attention/ipex.py

@@ -34,7 +34,7 @@ def attention(
     if ATTENTION == "flashdecoding":
         ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
             out,
-            query,
+            query.contiguous() if query.device.type == "xpu" else query,
             kv_cache.key,
             kv_cache.value,
             seqlen.cu_seqlen_q,
@@ -87,7 +87,7 @@ def paged_attention(
     if ATTENTION == "flashdecoding":
         ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
             out,
-            query,
+            query.contiguous() if query.device.type == "xpu" else query,
             kv_cache.key,
             kv_cache.value,
             seqlen.cu_seqlen_q,
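
Both hunks make the same one-line change: on an XPU device the query tensor is made contiguous before being passed to ipex.llm.modules.PagedAttention.flash_attn_varlen_func, presumably because the XPU kernel expects a dense buffer while the incoming query may be a strided view. A minimal sketch of the pattern in plain PyTorch (the ensure_dense helper is ours, not part of the diff):

import torch

def ensure_dense(query: torch.Tensor) -> torch.Tensor:
    # Same guard as the diff: only XPU tensors are copied. Note that
    # .contiguous() returns the tensor itself when it is already dense,
    # so the extra copy is only paid for strided views.
    return query.contiguous() if query.device.type == "xpu" else query

# CPU illustration: a transpose is a strided view over the same storage.
q = torch.randn(4, 8).t()
print(q.is_contiguous())               # False
print(q.contiguous().is_contiguous())  # True
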

server/text_generation_server/layers/attention/kv_cache.py

@@ -66,7 +66,9 @@ class KVCache:
         else:
             x = BLOCK_SIZE // element_size
-        if ATTENTION in {"flashdecoding", "flashinfer"} and SYSTEM != "ipex":
+        if ATTENTION in {"flashdecoding", "flashinfer"} and not (
+            SYSTEM == "ipex" and device == torch.device("cpu")
+        ):
             self.kv_cache = (
                 torch.empty(
                     (num_blocks, BLOCK_SIZE, num_heads, head_size),
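
Before this commit the flat flashdecoding cache layout was skipped on every IPEX system; now it is skipped only for IPEX on CPU, so IPEX on XPU allocates the (num_blocks, BLOCK_SIZE, num_heads, head_size) tensors the flashdecoding kernels expect. A sketch of the branch, with the CPU fallback shape purely illustrative since the diff elides that path:

import torch

def allocate_kv_cache(num_blocks: int, num_heads: int, head_size: int,
                      block_size: int, attention: str, system: str,
                      device: torch.device, dtype: torch.dtype = torch.float16):
    # Predicate copied from the hunk: only IPEX-on-CPU keeps its own layout.
    if attention in {"flashdecoding", "flashinfer"} and not (
        system == "ipex" and device == torch.device("cpu")
    ):
        # Flat per-block layout used by flash_attn_varlen_func-style kernels.
        shape = (num_blocks, block_size, num_heads, head_size)
    else:
        # CPU paged layout (illustrative; not shown in the diff).
        shape = (num_blocks, num_heads, block_size, head_size)
    key = torch.empty(shape, dtype=dtype, device=device)
    value = torch.empty(shape, dtype=dtype, device=device)
    return key, value
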
@@ -174,7 +176,9 @@ class KVCache:
                 scalar=True,
             )[0]
-        if ATTENTION in {"flashdecoding", "flashinfer"} and SYSTEM != "ipex":
+        if ATTENTION in {"flashdecoding", "flashinfer"} and not (
+            SYSTEM == "ipex" and key.device == torch.device("cpu")
+        ):
             key = key.to(key_cache.dtype)
             value = value.to(value_cache.dtype)
             if key_cache.dtype in {torch.float8_e4m3fn, torch.float8_e5m2}:
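
The store path uses the same predicate, this time tested against the device of the incoming key tensor rather than the device the cache was built on. The check is easy to exercise standalone, since torch.device objects can be constructed even for backends that are not installed:

import torch

def uses_flash_layout(system: str, attention: str, device: torch.device) -> bool:
    # Shared predicate from both kv_cache.py hunks: everything takes the
    # flashdecoding/flashinfer path except IPEX running on CPU.
    return attention in {"flashdecoding", "flashinfer"} and not (
        system == "ipex" and device == torch.device("cpu")
    )

print(uses_flash_layout("ipex", "flashdecoding", torch.device("cpu")))   # False
print(uses_flash_layout("ipex", "flashdecoding", torch.device("xpu")))   # True
print(uses_flash_layout("cuda", "flashdecoding", torch.device("cuda")))  # True
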

server/text_generation_server/models/globals.py

@@ -4,6 +4,7 @@ from loguru import logger
 from typing import Dict, Optional
 from text_generation_server.utils.log import log_master
+from text_generation_server.utils.import_utils import SYSTEM
 
 ATTENTION = os.environ["ATTENTION"]
 # default_prefix_caching = "1" if ATTENTION in {"flashinfer", "flashdecoding"} else "0"
@@ -27,9 +28,12 @@ TGI_WIGGLE_ROOM = float(os.getenv("TGI_WIGGLE_ROOM", "0.90"))
 assert TGI_WIGGLE_ROOM > 0
 assert TGI_WIGGLE_ROOM < 1
 
 # This is overridden by the cli
 BLOCK_SIZE: int
-if ATTENTION == "flashdecoding":
+if SYSTEM == "ipex":
+    BLOCK_SIZE = 16
+elif ATTENTION == "flashdecoding":
     BLOCK_SIZE = 256
 elif ATTENTION == "flashinfer":
     BLOCK_SIZE = 1
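
The selection order now checks SYSTEM first, so every IPEX run, CPU or XPU, gets 16-token pages, while the flashdecoding value of 256 and the flashinfer value of 1 apply only to the remaining systems. Restated as a standalone function; the trailing default is an assumption, not part of this hunk:

def choose_block_size(system: str, attention: str) -> int:
    # Order matters: the new IPEX branch shadows the flashdecoding one,
    # so IPEX + flashdecoding now yields 16 where it used to yield 256.
    if system == "ipex":
        return 16
    elif attention == "flashdecoding":
        return 256
    elif attention == "flashinfer":
        return 1
    return 16  # assumed default for other backends (elided in the diff)

print(choose_block_size("ipex", "flashdecoding"))  # 16 (was 256 before this commit)
print(choose_block_size("cuda", "flashdecoding"))  # 256 (unchanged)
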