Mirror of https://github.com/huggingface/text-generation-inference.git
enable xpu flashdecoding
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
parent d7c991b0d1
commit d04c86c76c
@@ -34,7 +34,7 @@ def attention(
     if ATTENTION == "flashdecoding":
         ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
             out,
-            query,
+            query.contiguous() if query.device.type == "xpu" else query,
             kv_cache.key,
             kv_cache.value,
             seqlen.cu_seqlen_q,
@@ -87,7 +87,7 @@ def paged_attention(
     if ATTENTION == "flashdecoding":
         ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
             out,
-            query,
+            query.contiguous() if query.device.type == "xpu" else query,
             kv_cache.key,
             kv_cache.value,
             seqlen.cu_seqlen_q,
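Note on the two attention hunks above: both call sites now pass the query through `.contiguous()` when it sits on an XPU device, presumably because the IPEX `flash_attn_varlen_func` kernel expects a contiguous query there; other devices are untouched. A minimal sketch of the same guard as a standalone helper (the helper name is illustrative and not part of the patch):

import torch


def _contiguous_query_for_xpu(query: torch.Tensor) -> torch.Tensor:
    # Mirrors the inline conditional in the diff: only XPU tensors are
    # forced into a contiguous layout; .contiguous() returns the same
    # tensor when the input is already contiguous.
    if query.device.type == "xpu":
        return query.contiguous()
    return query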
@@ -66,7 +66,9 @@ class KVCache:
         else:
             x = BLOCK_SIZE // element_size
 
-        if ATTENTION in {"flashdecoding", "flashinfer"} and SYSTEM != "ipex":
+        if ATTENTION in {"flashdecoding", "flashinfer"} and not (
+            SYSTEM == "ipex" and device == torch.device("cpu")
+        ):
             self.kv_cache = (
                 torch.empty(
                     (num_blocks, BLOCK_SIZE, num_heads, head_size),
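In the branch taken above, the key and value caches are each allocated in the flashdecoding layout `(num_blocks, BLOCK_SIZE, num_heads, head_size)`; after this change, IPEX on XPU also follows that path, while IPEX on CPU keeps the previous layout. A small allocation sketch with made-up sizes (the real values come from the model configuration, not from this patch):

import torch

# Example sizes only; BLOCK_SIZE = 16 is chosen to match the IPEX block
# size introduced later in this commit.
num_blocks, BLOCK_SIZE, num_heads, head_size = 128, 16, 8, 64

key_cache = torch.empty(
    (num_blocks, BLOCK_SIZE, num_heads, head_size), dtype=torch.float16
)
value_cache = torch.empty_like(key_cache)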
@@ -174,7 +176,9 @@ class KVCache:
                 scalar=True,
             )[0]
 
-        if ATTENTION in {"flashdecoding", "flashinfer"} and SYSTEM != "ipex":
+        if ATTENTION in {"flashdecoding", "flashinfer"} and not (
+            SYSTEM == "ipex" and key.device == torch.device("cpu")
+        ):
             key = key.to(key_cache.dtype)
             value = value.to(value_cache.dtype)
             if key_cache.dtype in {torch.float8_e4m3fn, torch.float8_e5m2}:
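Both KVCache hunks replace the blanket `SYSTEM != "ipex"` test with a device-aware one, so the flashdecoding/flashinfer allocation and store paths now apply to IPEX on XPU and are skipped only for IPEX on CPU. A rough restatement of the predicate as a hypothetical helper:

import torch


def _uses_flash_layout(attention: str, system: str, device: torch.device) -> bool:
    # Same condition as in the diff: the flashdecoding/flashinfer layout
    # applies unless the backend is IPEX *and* the tensor lives on CPU.
    return attention in {"flashdecoding", "flashinfer"} and not (
        system == "ipex" and device == torch.device("cpu")
    )


# _uses_flash_layout("flashdecoding", "ipex", torch.device("xpu")) -> True
# _uses_flash_layout("flashdecoding", "ipex", torch.device("cpu")) -> False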
@@ -4,6 +4,7 @@ from loguru import logger
 from typing import Dict, Optional
 
 from text_generation_server.utils.log import log_master
+from text_generation_server.utils.import_utils import SYSTEM
 
 ATTENTION = os.environ["ATTENTION"]
 # default_prefix_caching = "1" if ATTENTION in {"flashinfer", "flashdecoding"} else "0"
@@ -27,9 +28,12 @@ TGI_WIGGLE_ROOM = float(os.getenv("TGI_WIGGLE_ROOM", "0.90"))
 assert TGI_WIGGLE_ROOM > 0
 assert TGI_WIGGLE_ROOM < 1
 
 
 # This is overridden by the cli
 BLOCK_SIZE: int
-if ATTENTION == "flashdecoding":
+if SYSTEM == "ipex":
+    BLOCK_SIZE = 16
+elif ATTENTION == "flashdecoding":
     BLOCK_SIZE = 256
 elif ATTENTION == "flashinfer":
     BLOCK_SIZE = 1
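With `SYSTEM` imported into globals.py, the block size is pinned to 16 whenever the IPEX backend is in use, while other systems keep 256 for flashdecoding and 1 for flashinfer. A hypothetical function restating only the branches visible in this hunk:

def select_block_size(system: str, attention: str) -> int:
    # Restates just the branches shown above; any fallback branch in the
    # original module is intentionally not reproduced here.
    if system == "ipex":
        return 16
    if attention == "flashdecoding":
        return 256
    if attention == "flashinfer":
        return 1
    raise ValueError(f"no block size shown for attention={attention!r}")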