Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-04-22 15:32:08 +00:00
Add fp8 kv cache for ROCm (#2856)

* add fp8 kv cache for rocm
* improvements
* update log statement
* remove bookkeeping field

This commit is contained in:
parent de19e7e844
commit c20025dbf7
@@ -52,13 +52,18 @@ class KVCache:
         device: torch.device,
     ):
         """Construct the key-value cache for a layer."""
-
-        if dtype in {torch.float8_e5m2, torch.float8_e4m3fn} and (
-            ATTENTION != "flashinfer" or SYSTEM != "cuda"
-        ):
-            raise ValueError(
-                "FP8 KV cache is currently only supported for flashinfer on CUDA"
-            )
+        if dtype in {torch.float8_e5m2, torch.float8_e4m3fn}:
+            if not (
+                (ATTENTION == "flashinfer" and SYSTEM == "cuda")
+                or (ATTENTION == "paged" and SYSTEM == "rocm")
+            ):
+                raise ValueError(
+                    "FP8 KV cache is currently only supported for flashinfer on CUDA and paged attention on ROCm. "
+                )
+            if SYSTEM == "rocm" and dtype == torch.float8_e5m2:
+                raise ValueError(
+                    "float8_e5m2 FP8 KV cache is not supported on AMD ROCm"
+                )
 
         element_size = torch.tensor([], dtype=dtype).element_size()
         if SYSTEM == "ipex" and device.type == "xpu":
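As a minimal standalone sketch of the constructor check in the hunk above: the same accept/reject logic, with ATTENTION and SYSTEM passed in as plain strings rather than the module-level globals used in the repository. The helper name validate_fp8_kv_cache is hypothetical and exists only for illustration.

import torch


def validate_fp8_kv_cache(dtype: torch.dtype, attention: str, system: str) -> None:
    """Raise if the requested FP8 KV cache dtype cannot be used on this backend."""
    if dtype in {torch.float8_e5m2, torch.float8_e4m3fn}:
        if not (
            (attention == "flashinfer" and system == "cuda")
            or (attention == "paged" and system == "rocm")
        ):
            raise ValueError(
                "FP8 KV cache is currently only supported for flashinfer on CUDA "
                "and paged attention on ROCm."
            )
        if system == "rocm" and dtype == torch.float8_e5m2:
            raise ValueError("float8_e5m2 FP8 KV cache is not supported on AMD ROCm")


validate_fp8_kv_cache(torch.float8_e4m3fn, "paged", "rocm")  # accepted
try:
    validate_fp8_kv_cache(torch.float8_e5m2, "paged", "rocm")  # e5m2 rejected on ROCm
except ValueError as err:
    print(err)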
@@ -113,21 +118,17 @@ class KVCache:
         """Check if the cache can be scaled by the given scales."""
         if kv_scales.key_scale_cpu == 1.0 and kv_scales.value_scale_cpu == 1.0:
             return False
-        elif (
-            self.dtype == torch.float8_e4m3fn
-            and ATTENTION == "flashinfer"
-            and SYSTEM == "cuda"
+        elif self.dtype == torch.float8_e4m3fn and (
+            (ATTENTION == "flashinfer" and SYSTEM == "cuda")
+            or (ATTENTION == "paged" and SYSTEM == "rocm")
         ):
-            log_once(
-                logger.info,
-                "Using FP8 KV cache scales",
-            )
+            log_once(logger.info, "Using FP8 KV cache scales")
             return True
         else:
             # We have scales, but not the correct FP8 cache type, so warn once.
             log_once(
                 logger.info,
-                "Ignoring FP8 KV cache scales, only float8_e4m3fn KV cache on flashinfer is supported",
+                "Ignoring FP8 KV cache scales, supported only for float8_e4m3fn KV cache with flashinfer on CUDA and paged attention on ROCm",
             )
             return False
 
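The widened can_scale() decision above boils down to: scales are honoured only for a float8_e4m3fn cache on flashinfer/CUDA or paged/ROCm, and ignored with a warning otherwise. A hedged sketch of that decision, with KVScales reduced to a plain dataclass stand-in and the logging dropped:

from dataclasses import dataclass

import torch


@dataclass
class KVScales:
    key_scale_cpu: float = 1.0
    value_scale_cpu: float = 1.0


def can_scale(cache_dtype: torch.dtype, scales: KVScales, attention: str, system: str) -> bool:
    if scales.key_scale_cpu == 1.0 and scales.value_scale_cpu == 1.0:
        return False  # no scales supplied, nothing to apply
    if cache_dtype == torch.float8_e4m3fn and (
        (attention == "flashinfer" and system == "cuda")
        or (attention == "paged" and system == "rocm")
    ):
        return True
    return False  # scales supplied, but this cache type/backend cannot use them


print(can_scale(torch.float8_e4m3fn, KVScales(0.5, 0.5), "paged", "rocm"))  # True
print(can_scale(torch.float8_e5m2, KVScales(0.5, 0.5), "paged", "rocm"))    # False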
@@ -161,7 +162,7 @@ class KVCache:
         key_cache = self.kv_cache[0]
         value_cache = self.kv_cache[1]
 
-        if self.can_scale(kv_scales):
+        if self.can_scale(kv_scales) and SYSTEM == "cuda":
             if kv_scales.key_scale_cpu != 1.0:
                 key = fp8_quantize(
                     key.float(),
@@ -197,7 +198,15 @@ class KVCache:
                 key, value, key_cache, value_cache, slots
             )
         else:
-            paged_reshape_and_cache(key, value, key_cache, value_cache, slots)
+            paged_reshape_and_cache(
+                key,
+                value,
+                key_cache,
+                value_cache,
+                slots,
+                kv_scales.key_scale_cpu,
+                kv_scales.value_scale_cpu,
+            )
 
 
 def paged_reshape_and_cache(
@@ -206,7 +215,10 @@ def paged_reshape_and_cache(
     key_cache: torch.Tensor,
     value_cache: torch.Tensor,
     slots: torch.Tensor,
+    k_scale: float = 1.0,
+    v_scale: float = 1.0,
 ):
+
     if SYSTEM == "cuda":
         try:
             import attention_kernels
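The new k_scale / v_scale parameters are per-tensor scales that the kernel applies when writing FP8 values into the cache. Below is a CPU-only reference of one common convention (quantize as value / scale, dequantize as value * scale); the exact convention used inside the ROCm kernel is not visible in this diff, so treat this as an assumption for intuition only.

import torch


def quantize_kv_reference(x: torch.Tensor, scale: float) -> torch.Tensor:
    # Assumed convention: fp8 = clamp(x / scale); dequantization multiplies by scale.
    finfo = torch.finfo(torch.float8_e4m3fn)
    return (x.float() / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)


key = torch.randn(4, 8, dtype=torch.float16)
scale = 0.02
key_fp8 = quantize_kv_reference(key, scale)
# Round-trip error introduced by storing K in FP8 with this scale.
print((key_fp8.float() * scale - key.float()).abs().max())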
@@ -224,8 +236,15 @@ def paged_reshape_and_cache(
             raise ImportError(
                 f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
             )
+
+        kv_cache_dtype = "auto"
+        if key_cache.dtype == torch.float8_e4m3fn:
+            key_cache = key_cache.view(torch.uint8)
+            value_cache = value_cache.view(torch.uint8)
+            kv_cache_dtype = "fp8"
+
         ops.reshape_and_cache(
-            key, value, key_cache, value_cache, slots, "auto", 1.0, 1.0
+            key, value, key_cache, value_cache, slots, kv_cache_dtype, k_scale, v_scale
         )
     elif SYSTEM == "ipex":
         import intel_extension_for_pytorch as ipex
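The bitcast in the hunk above is the key trick on the ROCm path: the cache is allocated as float8_e4m3fn, but the kernel expects a uint8 buffer plus the "fp8" kv_cache_dtype string, so the tensors are reinterpreted in place with Tensor.view(dtype) rather than copied. A small sketch of just that step:

import torch

# Toy cache in the FP8 storage dtype, as the KVCache constructor allocates it.
key_cache = torch.zeros(2, 4, dtype=torch.float8_e4m3fn)

# Reinterpret the same storage as bytes for the kernel; no data is copied.
key_cache_u8 = key_cache.view(torch.uint8)
assert key_cache_u8.data_ptr() == key_cache.data_ptr()

# The dtype string that accompanies the byte view in the kernel call.
kv_cache_dtype = "fp8" if key_cache.dtype == torch.float8_e4m3fn else "auto"
print(kv_cache_dtype, key_cache_u8.dtype, key_cache_u8.shape)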
@@ -133,6 +133,15 @@ def paged_attention(
 
     out = torch.empty_like(query)
 
+    if kv_cache.dtype == torch.float8_e4m3fn:
+        key = kv_cache.key.view(torch.uint8)
+        value = kv_cache.value.view(torch.uint8)
+        kv_cache_dtype = "fp8"
+    else:
+        key = kv_cache.key
+        value = kv_cache.value
+        kv_cache_dtype = "auto"
+
     # NOTE(woosuk): We use a simple heuristic to decide whether to use
     # PagedAttention V1 or V2. If the number of partitions is 1, we use
     # V1 to avoid the overhead of reduction. Also, if the number of
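For context on the V1/V2 note kept in the hunk above: the number of partitions is derived from the longest sequence length by ceiling division against the partition size, and a single partition lets V1 skip the reduction step. A sketch assuming a partition size of 512; the actual value and the full heuristic in the repository may differ.

_PARTITION_SIZE = 512  # assumed for illustration


def max_num_partitions(max_s: int) -> int:
    # Ceiling division of the longest sequence length by the partition size.
    return (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE


print(max_num_partitions(300))   # 1 partition: the V1 path avoids the reduction
print(max_num_partitions(4096))  # several partitions: V2 or the ROCm kernel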
@@ -147,8 +156,8 @@ def paged_attention(
         ops.paged_attention_v1(
             out,
             query,
-            kv_cache.key,
-            kv_cache.value,
+            key,
+            value,
             num_kv_heads,
             softmax_scale,
             block_tables,
@@ -156,9 +165,9 @@ def paged_attention(
             block_size,
             max_s,
             None,
-            "auto",
-            1.0,
-            1.0,
+            kv_cache_dtype,
+            kv_scales.key_scale_cpu,
+            kv_scales.value_scale_cpu,
         )
     else:
         # Run PagedAttention V2.
@@ -182,8 +191,8 @@ def paged_attention(
                 max_logits,
                 tmp_output,
                 query,
-                kv_cache.key,
-                kv_cache.value,
+                key,
+                value,
                 num_kv_heads,
                 softmax_scale,
                 block_tables,
@@ -191,9 +200,9 @@ def paged_attention(
                 block_size,
                 max_s,
                 None,
-                "auto",
-                1.0,
-                1.0,
+                kv_cache_dtype,
+                kv_scales.key_scale_cpu,
+                kv_scales.value_scale_cpu,
             )
         else:
             ops.paged_attention_rocm(
@@ -202,8 +211,8 @@ def paged_attention(
                 max_logits,
                 tmp_output,
                 query,
-                kv_cache.key,
-                kv_cache.value,
+                key,
+                value,
                 num_kv_heads,
                 softmax_scale,
                 block_tables,
@@ -211,9 +220,9 @@ def paged_attention(
                 block_size,
                 max_s,
                 None,
-                "auto",
-                1.0,
-                1.0,
+                kv_cache_dtype,
+                kv_scales.key_scale_cpu,
+                kv_scales.value_scale_cpu,
                 None,
                 _PARTITION_SIZE,
             )
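For reference on the two FP8 formats involved in this change: only float8_e4m3fn is accepted for the paged cache on ROCm, while float8_e5m2 (more exponent bits, fewer mantissa bits) is rejected in the constructor hunk above. Their numeric envelopes can be inspected directly:

import torch

for dt in (torch.float8_e4m3fn, torch.float8_e5m2):
    fi = torch.finfo(dt)
    print(f"{dt}: max={fi.max}, eps={fi.eps}, smallest_normal={fi.smallest_normal}")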