add flashdecoding-ipex

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Author: Wang, Yi A
Date: 2024-12-19 18:33:25 -08:00
Parent: ac67673788
Commit: d6ac8cdf81

6 changed files with 30 additions and 18 deletions


@@ -224,7 +224,7 @@ COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/loca
 COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher

 FROM ${PLATFORM} AS final
-ENV ATTENTION=flashdecoding
+ENV ATTENTION=flashdecoding-ipex
 ENV PREFIX_CACHING=1
 ENV PREFILL_CHUNKING=1
 ENV CUDA_GRAPHS=0
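The hunk above switches the attention default baked into the Intel image from flashdecoding to flashdecoding-ipex. Since this is only an ENV default, it can still be overridden when the container is started (for example with docker run -e ATTENTION=paged). A minimal sketch of how such a default reaches the Python server, reusing the validation pattern that appears further down in this commit; the helper name and structure here are hypothetical, not part of the diff:

import os

# Hypothetical helper: shows the precedence between the image default
# (ENV ATTENTION=flashdecoding-ipex) and a user-supplied override.
IMAGE_DEFAULT = "flashdecoding-ipex"
VALID_BACKENDS = {"paged", "flashdecoding", "flashdecoding-ipex", "flashinfer"}

def effective_attention() -> str:
    value = os.environ.get("ATTENTION", IMAGE_DEFAULT)
    if value not in VALID_BACKENDS:
        raise ValueError(f"Attention is not valid {value}, expected {VALID_BACKENDS}")
    return value

print(effective_attention())  # "flashdecoding-ipex" unless ATTENTION is overridden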


@@ -144,7 +144,9 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
        }
    }

-    let fallback_attention = if matches!(compute_capability, Some((major, _)) if major < 8) {
+    let fallback_attention = if compute_capability.is_none()
+        || matches!(compute_capability, Some((major, _)) if major < 8)
+    {
         "paged"
     } else {
         "flashdecoding"


@@ -31,7 +31,7 @@ def attention(
     out = torch.empty_like(query)

     # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
-    if ATTENTION == "flashdecoding":
+    if ATTENTION == "flashdecoding-ipex":
         ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
             out,
             query.contiguous() if query.device.type == "xpu" else query,
@@ -84,7 +84,7 @@ def paged_attention(
     out = torch.empty_like(query)

-    if ATTENTION == "flashdecoding":
+    if ATTENTION == "flashdecoding-ipex":
         ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
             out,
             query.contiguous() if query.device.type == "xpu" else query,
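Both hunks above gate the IPEX varlen flash kernel on the new flashdecoding-ipex value instead of flashdecoding, and they hand the kernel a contiguous copy of query when it lives on an XPU device. A tiny illustrative helper capturing that pattern; it is not part of the diff, and the real kernel call (ipex.llm.modules.PagedAttention.flash_attn_varlen_func) is deliberately not reproduced here:

import torch

def query_for_ipex_flash(query: torch.Tensor, attention: str) -> torch.Tensor:
    """Illustrative only: the IPEX varlen flash path is taken when ATTENTION is
    "flashdecoding-ipex", and XPU tensors are made contiguous before the call,
    matching the pattern in both hunks above."""
    if attention != "flashdecoding-ipex":
        raise ValueError("varlen flash path is only used for flashdecoding-ipex")
    return query.contiguous() if query.device.type == "xpu" else query

# CPU tensors pass through unchanged:
q = torch.randn(4, 8, 64)
assert query_for_ipex_flash(q, "flashdecoding-ipex") is q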


@@ -66,8 +66,8 @@ class KVCache:
         else:
             x = BLOCK_SIZE // element_size

-        if ATTENTION in {"flashdecoding", "flashinfer"} and not (
-            SYSTEM == "ipex" and device == torch.device("cpu")
+        if ATTENTION in {"flashdecoding", "flashinfer"} or (
+            ATTENTION == "flashdecoding-ipex" and device.type == "xpu"
         ):
             self.kv_cache = (
                 torch.empty(
@@ -82,6 +82,7 @@ class KVCache:
                 ),
             )
         elif SYSTEM == "ipex" and device == torch.device("cpu"):
+            # ipex cpu flashdecoding kernel and paged attention kernel share same layout
             self.kv_cache = (
                 torch.empty(
                     (num_blocks, num_heads, BLOCK_SIZE, head_size),
@@ -176,9 +177,7 @@ class KVCache:
                     scalar=True,
                 )[0]

-        if ATTENTION in {"flashdecoding", "flashinfer"} and not (
-            SYSTEM == "ipex" and key.device == torch.device("cpu")
-        ):
+        if ATTENTION in {"flashdecoding", "flashinfer"}:
             key = key.to(key_cache.dtype)
             value = value.to(value_cache.dtype)
             if key_cache.dtype in {torch.float8_e4m3fn, torch.float8_e5m2}:
@@ -191,6 +190,12 @@ class KVCache:
             shape = key_cache.shape
             key_cache.view(-1, shape[-2], shape[-1])[slots] = key
             value_cache.view(-1, shape[-2], shape[-1])[slots] = value
+        elif ATTENTION == "flashdecoding-ipex" and key.device.type == "xpu":
+            import intel_extension_for_pytorch as ipex
+
+            ipex.llm.modules.PagedAttention.reshape_and_cache_flash(
+                key, value, key_cache, value_cache, slots
+            )
         else:
             paged_reshape_and_cache(key, value, key_cache, value_cache, slots)
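The KVCache changes do two things. Allocation now treats flashdecoding-ipex on an XPU device like flashdecoding/flashinfer, while the IPEX CPU path keeps its own (num_blocks, num_heads, BLOCK_SIZE, head_size) layout, which (per the new comment) is shared by its flashdecoding and paged-attention kernels. store() then routes XPU writes through ipex.llm.modules.PagedAttention.reshape_and_cache_flash instead of the generic paged_reshape_and_cache. A sketch of the layout selection; the shape used by the flash branch is an assumption here, since the hunk is cut off before its torch.empty call:

import torch

BLOCK_SIZE = 64  # flashdecoding-ipex block size, per the mapping further down

def kv_cache_shape(attention: str, device: torch.device,
                   num_blocks: int, num_heads: int, head_size: int):
    # Sketch of the branch structure in KVCache.__init__ after this commit.
    if attention in {"flashdecoding", "flashinfer"} or (
        attention == "flashdecoding-ipex" and device.type == "xpu"
    ):
        # Assumed flat flash layout; not shown in the truncated hunk above.
        return (num_blocks, BLOCK_SIZE, num_heads, head_size)
    if device == torch.device("cpu"):
        # IPEX CPU layout, shared by its flashdecoding and paged attention kernels.
        return (num_blocks, num_heads, BLOCK_SIZE, head_size)
    raise NotImplementedError("paged (x-split) layout omitted from this sketch")

print(kv_cache_shape("flashdecoding", torch.device("cuda"), 1, 2, 128))
print(kv_cache_shape("flashdecoding-ipex", torch.device("cpu"), 1, 2, 128))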


@@ -4,7 +4,6 @@ from loguru import logger
 from typing import Dict, Optional

 from text_generation_server.utils.log import log_master
-from text_generation_server.utils.import_utils import SYSTEM

 REQUEST_LOGPROBS = os.getenv("REQUEST_LOGPROBS", "0").lower() in {"1", "true"}
 ATTENTION = os.environ["ATTENTION"]
@@ -15,13 +14,17 @@ PREFIX_CACHING = os.environ["PREFIX_CACHING"].lower() in {
 }
 PREFILL_CHUNKING = os.getenv("PREFILL_CHUNKING", "1").lower() in {"1", "true"}
 log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}")
-_expected = {"paged", "flashdecoding", "flashinfer"}
+_expected = {"paged", "flashdecoding", "flashdecoding-ipex", "flashinfer"}
 assert (
     ATTENTION in _expected
 ), f"Attention is not valid {ATTENTION}, expected {_expected}"
 log_master(logger.info, f"Using Attention = {ATTENTION}")

-if PREFIX_CACHING and ATTENTION not in {"flashinfer", "flashdecoding"}:
+if PREFIX_CACHING and ATTENTION not in {
+    "flashinfer",
+    "flashdecoding",
+    "flashdecoding-ipex",
+}:
     raise RuntimeError("Prefix caching is only supported with flashinfer")

 MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
@@ -33,12 +36,11 @@ assert TGI_WIGGLE_ROOM < 1

 # This is overridden by the cli
 BLOCK_SIZE: int
 if ATTENTION == "flashdecoding":
-    if SYSTEM == "ipex":
-        BLOCK_SIZE = 64
-    else:
-        BLOCK_SIZE = 256
+    BLOCK_SIZE = 256
 elif ATTENTION == "flashinfer":
     BLOCK_SIZE = 1
+elif ATTENTION == "flashdecoding-ipex":
+    BLOCK_SIZE = 64
 else:
     BLOCK_SIZE = 16
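With the SYSTEM special-case removed, the block size is now a pure function of the attention backend. Restated as a small illustrative mapping (values taken from the hunk above; as the comment notes, the CLI can still override BLOCK_SIZE):

# Backend-to-block-size mapping after this commit (illustrative restatement).
BLOCK_SIZES = {
    "flashdecoding": 256,
    "flashinfer": 1,
    "flashdecoding-ipex": 64,
    "paged": 16,  # the `else` branch; "paged" is the only other allowed value
}

def block_size(attention: str) -> int:
    return BLOCK_SIZES.get(attention, 16)

assert block_size("flashdecoding-ipex") == 64
assert block_size("paged") == 16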


@@ -79,10 +79,13 @@ class Model(ABC):
                 "Prefill chunking will be turned off",
             )
             support_chunking = False
-        if ATTENTION not in ["flashinfer", "flashdecoding"] and support_chunking:
+        if (
+            ATTENTION not in ["flashinfer", "flashdecoding", "flashdecoding-ipex"]
+            and support_chunking
+        ):
             log_master(
                 logger.warning,
-                "Prefill chunking is only supported with `flashinfer` or `flashdecoding` attention types.",
+                "Prefill chunking is only supported with `flashinfer` or `flashdecoding` or `flashdecoding-ipex` attention types.",
             )
             support_chunking = False
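Finally, prefill chunking is now accepted for flashdecoding-ipex as well; with any other backend (effectively paged), requested chunking support is still disabled with a warning. A reduced, illustrative form of the decision (the real code also emits the log_master warning shown above):

def resolve_support_chunking(attention: str, requested: bool) -> bool:
    # Illustrative reduction of the check above; warning/logging omitted.
    if requested and attention not in ("flashinfer", "flashdecoding", "flashdecoding-ipex"):
        return False
    return requested

assert resolve_support_chunking("flashdecoding-ipex", True) is True
assert resolve_support_chunking("paged", True) is False
assert resolve_support_chunking("paged", False) is False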