Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-11 12:24:53 +00:00
remove bookkeeping field

commit 5b10e5bccf
parent 8ffb5b3697
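In effect: KVCache drops its kv_cache_dtype bookkeeping field ("auto" vs. "fp8") that was fixed at construction time. On ROCm the cache now also keeps its torch.float8_e4m3fn dtype instead of being pre-cast to uint8, and the call sites that feed the vllm kernels derive the dtype string, and the uint8 view, from the cache tensors themselves.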
@@ -65,11 +65,6 @@ class KVCache:
                     "float8_e5m2 FP8 KV cache is not supported on AMD ROCm"
                 )
 
-        self.kv_cache_dtype = "auto"
-        if SYSTEM == "rocm" and dtype == torch.float8_e4m3fn:
-            self.kv_cache_dtype = "fp8"
-            dtype = torch.uint8
-
         element_size = torch.tensor([], dtype=dtype).element_size()
         if SYSTEM == "ipex" and device.type == "xpu":
             x = 1
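With the field and the uint8 pre-cast gone, the "fp8" tag is recoverable from the cache tensor's dtype alone. A minimal sketch of the derivation that the later hunks inline at each call site (the helper name here is hypothetical; the commit does not introduce it):

    import torch

    def derive_kv_cache_dtype(cache: torch.Tensor) -> str:
        # An fp8 cache is identified by its dtype; everything else
        # falls back to the vllm kernels' "auto" behavior.
        return "fp8" if cache.dtype == torch.float8_e4m3fn else "auto"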
@@ -120,12 +115,9 @@ class KVCache:
         """Check if the cache can be scaled by the given scales."""
         if kv_scales.key_scale_cpu == 1.0 and kv_scales.value_scale_cpu == 1.0:
             return False
-        elif (
-            self.dtype == torch.float8_e4m3fn
-            and ATTENTION == "flashinfer"
-            and SYSTEM == "cuda"
-        ) or (
-            self.kv_cache_dtype == "fp8" and ATTENTION == "paged" and SYSTEM == "rocm"
+        elif self.dtype == torch.float8_e4m3fn and (
+            (ATTENTION == "flashinfer" and SYSTEM == "cuda")
+            or (ATTENTION == "paged" and SYSTEM == "rocm")
         ):
             log_once(logger.info, "Using FP8 KV cache scales")
             return True
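This simplification is equivalence-preserving only because of the previous hunk: since the ROCm cache is no longer re-cast to uint8, both platforms can test the same dtype, and the two branches factor into one predicate. As a standalone sketch (hypothetical free function; the module keeps this logic inline in can_scale):

    import torch

    def uses_fp8_scales(dtype: torch.dtype, attention: str, system: str) -> bool:
        # attention/system stand in for the module-level ATTENTION and
        # SYSTEM strings used by can_scale.
        return dtype == torch.float8_e4m3fn and (
            (attention == "flashinfer" and system == "cuda")
            or (attention == "paged" and system == "rocm")
        )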
@@ -203,7 +195,6 @@ class KVCache:
             key_cache,
             value_cache,
             slots,
-            self.kv_cache_dtype,
             kv_scales.key_scale_cpu,
             kv_scales.value_scale_cpu,
         )
@@ -215,7 +206,6 @@ def paged_reshape_and_cache(
     key_cache: torch.Tensor,
     value_cache: torch.Tensor,
     slots: torch.Tensor,
-    kv_cache_dtype: str = "auto",
     k_scale: float = 1.0,
     v_scale: float = 1.0,
 ):
@@ -237,6 +227,13 @@ def paged_reshape_and_cache(
         raise ImportError(
             f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
         )
+
+    kv_cache_dtype = "auto"
+    if key_cache.dtype == torch.float8_e4m3fn:
+        key_cache = key_cache.view(torch.uint8)
+        value_cache = value_cache.view(torch.uint8)
+        kv_cache_dtype = "fp8"
+
     ops.reshape_and_cache(
         key, value, key_cache, value_cache, slots, kv_cache_dtype, k_scale, v_scale
     )
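The vllm reshape_and_cache kernel receives fp8 caches as raw uint8 storage, so the tensors are reinterpreted rather than converted: torch.Tensor.view(dtype) between two one-byte dtypes is a zero-copy bit reinterpretation. A quick illustration (assumes a PyTorch build with float8 dtypes, roughly 2.1+):

    import torch

    cache = torch.zeros(16, dtype=torch.float8_e4m3fn)
    raw = cache.view(torch.uint8)                  # same storage, no copy
    assert raw.shape == cache.shape                # itemsizes match (1 byte each)
    assert raw.data_ptr() == cache.data_ptr()      # same underlying buffer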
(second file in this commit: the hunk offsets restart below; these hunks modify a paged_attention implementation, apparently the ROCm attention backend)
@@ -133,6 +133,15 @@ def paged_attention(
 
     out = torch.empty_like(query)
 
+    if kv_cache.dtype == torch.float8_e4m3fn:
+        key = kv_cache.key.view(torch.uint8)
+        value = kv_cache.value.view(torch.uint8)
+        kv_cache_dtype = "fp8"
+    else:
+        key = kv_cache.key
+        value = kv_cache.value
+        kv_cache_dtype = "auto"
+
     # NOTE(woosuk): We use a simple heuristic to decide whether to use
     # PagedAttention V1 or V2. If the number of partitions is 1, we use
     # V1 to avoid the overhead of reduction. Also, if the number of
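The remaining hunks are mechanical: each ops.paged_attention_v1/ops.paged_attention_v2 call switches from the kv_cache.key, kv_cache.value, and kv_cache.kv_cache_dtype attributes to the key, value, and kv_cache_dtype locals computed above.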
@@ -147,8 +156,8 @@ def paged_attention(
         ops.paged_attention_v1(
             out,
             query,
-            kv_cache.key,
-            kv_cache.value,
+            key,
+            value,
             num_kv_heads,
             softmax_scale,
             block_tables,
@@ -156,7 +165,7 @@ def paged_attention(
             block_size,
             max_s,
             None,
-            kv_cache.kv_cache_dtype,
+            kv_cache_dtype,
             kv_scales.key_scale_cpu,
             kv_scales.value_scale_cpu,
         )
@@ -182,8 +191,8 @@ def paged_attention(
             max_logits,
             tmp_output,
             query,
-            kv_cache.key,
-            kv_cache.value,
+            key,
+            value,
             num_kv_heads,
             softmax_scale,
             block_tables,
@@ -191,7 +200,7 @@ def paged_attention(
             block_size,
             max_s,
             None,
-            kv_cache.kv_cache_dtype,
+            kv_cache_dtype,
             kv_scales.key_scale_cpu,
             kv_scales.value_scale_cpu,
         )
@@ -202,8 +211,8 @@ def paged_attention(
             max_logits,
             tmp_output,
             query,
-            kv_cache.key,
-            kv_cache.value,
+            key,
+            value,
             num_kv_heads,
             softmax_scale,
             block_tables,
@@ -211,7 +220,7 @@ def paged_attention(
             block_size,
             max_s,
             None,
-            kv_cache.kv_cache_dtype,
+            kv_cache_dtype,
             kv_scales.key_scale_cpu,
             kv_scales.value_scale_cpu,
             None,