mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
add flashdecoding-ipex
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
parent
ac67673788
commit
d6ac8cdf81
@ -224,7 +224,7 @@ COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/loca
|
|||||||
COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher
|
COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher
|
||||||
|
|
||||||
FROM ${PLATFORM} AS final
|
FROM ${PLATFORM} AS final
|
||||||
ENV ATTENTION=flashdecoding
|
ENV ATTENTION=flashdecoding-ipex
|
||||||
ENV PREFIX_CACHING=1
|
ENV PREFIX_CACHING=1
|
||||||
ENV PREFILL_CHUNKING=1
|
ENV PREFILL_CHUNKING=1
|
||||||
ENV CUDA_GRAPHS=0
|
ENV CUDA_GRAPHS=0
|
||||||
|
@ -144,7 +144,9 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let fallback_attention = if matches!(compute_capability, Some((major, _)) if major < 8) {
|
let fallback_attention = if compute_capability.is_none()
|
||||||
|
|| matches!(compute_capability, Some((major, _)) if major < 8)
|
||||||
|
{
|
||||||
"paged"
|
"paged"
|
||||||
} else {
|
} else {
|
||||||
"flashdecoding"
|
"flashdecoding"
|
||||||
|
@ -31,7 +31,7 @@ def attention(
|
|||||||
out = torch.empty_like(query)
|
out = torch.empty_like(query)
|
||||||
|
|
||||||
# We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
|
# We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
|
||||||
if ATTENTION == "flashdecoding":
|
if ATTENTION == "flashdecoding-ipex":
|
||||||
ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
|
ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
|
||||||
out,
|
out,
|
||||||
query.contiguous() if query.device.type == "xpu" else query,
|
query.contiguous() if query.device.type == "xpu" else query,
|
||||||
@ -84,7 +84,7 @@ def paged_attention(
|
|||||||
|
|
||||||
out = torch.empty_like(query)
|
out = torch.empty_like(query)
|
||||||
|
|
||||||
if ATTENTION == "flashdecoding":
|
if ATTENTION == "flashdecoding-ipex":
|
||||||
ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
|
ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
|
||||||
out,
|
out,
|
||||||
query.contiguous() if query.device.type == "xpu" else query,
|
query.contiguous() if query.device.type == "xpu" else query,
|
||||||
|
@ -66,8 +66,8 @@ class KVCache:
|
|||||||
else:
|
else:
|
||||||
x = BLOCK_SIZE // element_size
|
x = BLOCK_SIZE // element_size
|
||||||
|
|
||||||
if ATTENTION in {"flashdecoding", "flashinfer"} and not (
|
if ATTENTION in {"flashdecoding", "flashinfer"} or (
|
||||||
SYSTEM == "ipex" and device == torch.device("cpu")
|
ATTENTION == "flashdecoding-ipex" and device.type == "xpu"
|
||||||
):
|
):
|
||||||
self.kv_cache = (
|
self.kv_cache = (
|
||||||
torch.empty(
|
torch.empty(
|
||||||
@ -82,6 +82,7 @@ class KVCache:
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
elif SYSTEM == "ipex" and device == torch.device("cpu"):
|
elif SYSTEM == "ipex" and device == torch.device("cpu"):
|
||||||
|
# ipex cpu flashdecoding kernel and paged attention kernel share same layout
|
||||||
self.kv_cache = (
|
self.kv_cache = (
|
||||||
torch.empty(
|
torch.empty(
|
||||||
(num_blocks, num_heads, BLOCK_SIZE, head_size),
|
(num_blocks, num_heads, BLOCK_SIZE, head_size),
|
||||||
@ -176,9 +177,7 @@ class KVCache:
|
|||||||
scalar=True,
|
scalar=True,
|
||||||
)[0]
|
)[0]
|
||||||
|
|
||||||
if ATTENTION in {"flashdecoding", "flashinfer"} and not (
|
if ATTENTION in {"flashdecoding", "flashinfer"}:
|
||||||
SYSTEM == "ipex" and key.device == torch.device("cpu")
|
|
||||||
):
|
|
||||||
key = key.to(key_cache.dtype)
|
key = key.to(key_cache.dtype)
|
||||||
value = value.to(value_cache.dtype)
|
value = value.to(value_cache.dtype)
|
||||||
if key_cache.dtype in {torch.float8_e4m3fn, torch.float8_e5m2}:
|
if key_cache.dtype in {torch.float8_e4m3fn, torch.float8_e5m2}:
|
||||||
@ -191,6 +190,12 @@ class KVCache:
|
|||||||
shape = key_cache.shape
|
shape = key_cache.shape
|
||||||
key_cache.view(-1, shape[-2], shape[-1])[slots] = key
|
key_cache.view(-1, shape[-2], shape[-1])[slots] = key
|
||||||
value_cache.view(-1, shape[-2], shape[-1])[slots] = value
|
value_cache.view(-1, shape[-2], shape[-1])[slots] = value
|
||||||
|
elif ATTENTION == "flashdecoding-ipex" and key.device.type == "xpu":
|
||||||
|
import intel_extension_for_pytorch as ipex
|
||||||
|
|
||||||
|
ipex.llm.modules.PagedAttention.reshape_and_cache_flash(
|
||||||
|
key, value, key_cache, value_cache, slots
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
paged_reshape_and_cache(key, value, key_cache, value_cache, slots)
|
paged_reshape_and_cache(key, value, key_cache, value_cache, slots)
|
||||||
|
|
||||||
|
@ -4,7 +4,6 @@ from loguru import logger
|
|||||||
from typing import Dict, Optional
|
from typing import Dict, Optional
|
||||||
|
|
||||||
from text_generation_server.utils.log import log_master
|
from text_generation_server.utils.log import log_master
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
|
||||||
|
|
||||||
REQUEST_LOGPROBS = os.getenv("REQUEST_LOGPROBS", "0").lower() in {"1", "true"}
|
REQUEST_LOGPROBS = os.getenv("REQUEST_LOGPROBS", "0").lower() in {"1", "true"}
|
||||||
ATTENTION = os.environ["ATTENTION"]
|
ATTENTION = os.environ["ATTENTION"]
|
||||||
@ -15,13 +14,17 @@ PREFIX_CACHING = os.environ["PREFIX_CACHING"].lower() in {
|
|||||||
}
|
}
|
||||||
PREFILL_CHUNKING = os.getenv("PREFILL_CHUNKING", "1").lower() in {"1", "true"}
|
PREFILL_CHUNKING = os.getenv("PREFILL_CHUNKING", "1").lower() in {"1", "true"}
|
||||||
log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}")
|
log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}")
|
||||||
_expected = {"paged", "flashdecoding", "flashinfer"}
|
_expected = {"paged", "flashdecoding", "flashdecoding-ipex", "flashinfer"}
|
||||||
assert (
|
assert (
|
||||||
ATTENTION in _expected
|
ATTENTION in _expected
|
||||||
), f"Attention is not valid {ATTENTION}, expected {_expected}"
|
), f"Attention is not valid {ATTENTION}, expected {_expected}"
|
||||||
log_master(logger.info, f"Using Attention = {ATTENTION}")
|
log_master(logger.info, f"Using Attention = {ATTENTION}")
|
||||||
|
|
||||||
if PREFIX_CACHING and ATTENTION not in {"flashinfer", "flashdecoding"}:
|
if PREFIX_CACHING and ATTENTION not in {
|
||||||
|
"flashinfer",
|
||||||
|
"flashdecoding",
|
||||||
|
"flashdecoding-ipex",
|
||||||
|
}:
|
||||||
raise RuntimeError("Prefix caching is only supported with flashinfer")
|
raise RuntimeError("Prefix caching is only supported with flashinfer")
|
||||||
|
|
||||||
MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
|
MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
|
||||||
@ -33,12 +36,11 @@ assert TGI_WIGGLE_ROOM < 1
|
|||||||
# This is overridden by the cli
|
# This is overridden by the cli
|
||||||
BLOCK_SIZE: int
|
BLOCK_SIZE: int
|
||||||
if ATTENTION == "flashdecoding":
|
if ATTENTION == "flashdecoding":
|
||||||
if SYSTEM == "ipex":
|
BLOCK_SIZE = 256
|
||||||
BLOCK_SIZE = 64
|
|
||||||
else:
|
|
||||||
BLOCK_SIZE = 256
|
|
||||||
elif ATTENTION == "flashinfer":
|
elif ATTENTION == "flashinfer":
|
||||||
BLOCK_SIZE = 1
|
BLOCK_SIZE = 1
|
||||||
|
elif ATTENTION == "flashdecoding-ipex":
|
||||||
|
BLOCK_SIZE = 64
|
||||||
else:
|
else:
|
||||||
BLOCK_SIZE = 16
|
BLOCK_SIZE = 16
|
||||||
|
|
||||||
|
@ -79,10 +79,13 @@ class Model(ABC):
|
|||||||
"Prefill chunking will be turned off",
|
"Prefill chunking will be turned off",
|
||||||
)
|
)
|
||||||
support_chunking = False
|
support_chunking = False
|
||||||
if ATTENTION not in ["flashinfer", "flashdecoding"] and support_chunking:
|
if (
|
||||||
|
ATTENTION not in ["flashinfer", "flashdecoding", "flashdecoding-ipex"]
|
||||||
|
and support_chunking
|
||||||
|
):
|
||||||
log_master(
|
log_master(
|
||||||
logger.warning,
|
logger.warning,
|
||||||
"Prefill chunking is only supported with `flashinfer` or `flashdecoding` attention types.",
|
"Prefill chunking is only supported with `flashinfer` or `flashdecoding` or `flashdecoding-ipex` attention types.",
|
||||||
)
|
)
|
||||||
support_chunking = False
|
support_chunking = False
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user