mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-04-23 16:02:10 +00:00
* Add support for FP8 KV cache scales

Since FP8 only has limited dynamic range, we can scale keys/values before storing them into the cache (and unscale them in attention). To avoid rescaling the cache as the absmax values change, good scales are usually determined per layer using calibration data and stored in the checkpoint.

This change adds support for using key-value scales and loading them from checkpoints in the two most common formats:

- Separate per-layer `k_scale` and `v_scale` scalars.
- Per-layer `kv_scale` scalar (older format).

Currently, scales are only used with a `float8_e4m3fn` cache.

Besides adding support for key/value scales, the `fp8_quantize` function is also extended to support quantization with a kernel vendored from vLLM. This is slightly faster than the PyTorch implementation, but also scales in FP32, potentially improving accuracy.

* Update FP8 KV cache test to use a checkpoint with scales

* `can_scale`: check that the attention is flashinfer
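As a minimal sketch of the two formats (the layer prefix `model.layers.0.self_attn` and the `weights` handle are illustrative assumptions, not part of this change), the per-layer lookup in `get_kv_scales` below resolves either:

    # Newer format: separate per-layer scalars.
    #   model.layers.0.self_attn.k_scale   -> float32 scalar
    #   model.layers.0.self_attn.v_scale   -> float32 scalar
    # Older format: a single per-layer scalar shared by keys and values.
    #   model.layers.0.self_attn.kv_scale  -> float32 scalar
    scales = get_kv_scales(weights, prefix="model.layers.0.self_attn")
    scales.key_scale_cpu, scales.value_scale_cpu  # plain floats; both 1.0 when the checkpoint has no scales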
247 lines
8.0 KiB
Python
from typing import Tuple
from dataclasses import dataclass, field

from loguru import logger
import torch

from text_generation_server.layers.fp8 import fp8_quantize
from text_generation_server.models.globals import ATTENTION, BLOCK_SIZE
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.log import log_once
from text_generation_server.utils.weights import Weights


@dataclass
class KVScales:
    """
    Key-value scales for FP8 KV cache.

    This data class stores key and value scales both as a GPU tensor and
    as a CPU float. This inconvenience is necessary because some functions
    (e.g. scaling kernels) take scales as a GPU tensor, whereas others
    (e.g. flashinfer) take scales as a CPU scalar.
    """

    key_scale: torch.Tensor
    value_scale: torch.Tensor
    key_scale_cpu: float = field(init=False)
    value_scale_cpu: float = field(init=False)

    def __post_init__(self):
        if self.key_scale.numel() != 1 or self.value_scale.numel() != 1:
            raise ValueError("Key and value scales must be scalar tensors.")

        self.key_scale_cpu = self.key_scale.item()
        self.value_scale_cpu = self.value_scale.item()
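    # Illustrative sketch (not part of the original module): __post_init__ mirrors
    # the scalar GPU tensors as plain Python floats, so callers can pass whichever
    # form a kernel expects:
    #
    #   scale = torch.tensor(1.0, dtype=torch.float32, device="cuda")
    #   kv_scales = KVScales(key_scale=scale, value_scale=scale)
    #   kv_scales.key_scale      # 0-dim GPU tensor, e.g. for scaling kernels
    #   kv_scales.key_scale_cpu  # 1.0 as a Python float, e.g. for flashinfer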


class KVCache:
    """
    Key-value cache for attention layers.
    """

    kv_cache: Tuple[torch.Tensor, torch.Tensor]

    def __init__(
        self,
        *,
        num_blocks: int,
        num_heads: int,
        head_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ):
        """Construct the key-value cache for a layer."""

        if dtype in {torch.float8_e5m2, torch.float8_e4m3fn} and (
            ATTENTION != "flashinfer" or SYSTEM != "cuda"
        ):
            raise ValueError(
                "FP8 KV cache is currently only supported for flashinfer on CUDA"
            )

        element_size = torch.tensor([], dtype=dtype).element_size()
        if SYSTEM == "ipex" and device.type == "xpu":
            x = 1
        else:
            x = BLOCK_SIZE // element_size

        if ATTENTION in {"flashdecoding", "flashinfer"}:
            self.kv_cache = (
                torch.empty(
                    (num_blocks, BLOCK_SIZE, num_heads, head_size),
                    dtype=dtype,
                    device=device,
                ),
                torch.empty(
                    (num_blocks, BLOCK_SIZE, num_heads, head_size),
                    dtype=dtype,
                    device=device,
                ),
            )
        elif SYSTEM == "ipex" and device == torch.device("cpu"):
            self.kv_cache = (
                torch.empty(
                    (num_blocks, num_heads, BLOCK_SIZE, head_size),
                    dtype=dtype,
                    device=device,
                ),
                torch.empty(
                    (num_blocks, num_heads, BLOCK_SIZE, head_size),
                    dtype=dtype,
                    device=device,
                ),
            )
        else:
            self.kv_cache = (
                torch.zeros(
                    (num_blocks, num_heads, head_size // x, BLOCK_SIZE, x),
                    dtype=dtype,
                    device=device,
                ),
                torch.zeros(
                    (num_blocks, num_heads, head_size, BLOCK_SIZE),
                    dtype=dtype,
                    device=device,
                ),
            )

    def can_scale(self, kv_scales: KVScales) -> bool:
        """Check if the cache can be scaled by the given scales."""
        if kv_scales.key_scale_cpu == 1.0 and kv_scales.value_scale_cpu == 1.0:
            return False
        elif (
            self.dtype == torch.float8_e4m3fn
            and ATTENTION == "flashinfer"
            and SYSTEM == "cuda"
        ):
            log_once(
                logger.info,
                "Using FP8 KV cache scales",
            )
            return True
        else:
            # We have scales, but not the correct FP8 cache type, so warn once.
            log_once(
                logger.info,
                "Ignoring FP8 KV cache scales, only float8_e4m3fn KV cache on flashinfer is supported",
            )
            return False

    @property
    def dtype(self):
        """Get the data type of the cache."""
        return self.kv_cache[0].dtype

    @property
    def key(self):
        """Get the key cache."""

        return self.kv_cache[0]

    @property
    def value(self):
        """Get the value cache."""

        return self.kv_cache[1]

    def store(
        self,
        *,
        key: torch.Tensor,
        value: torch.Tensor,
        slots: torch.Tensor,
        kv_scales: KVScales,
    ):
        """Store the key and value at the given slots."""

        key_cache = self.kv_cache[0]
        value_cache = self.kv_cache[1]

        if self.can_scale(kv_scales):
            if kv_scales.key_scale_cpu != 1.0:
                key = fp8_quantize(
                    key.float(),
                    scale=kv_scales.key_scale,
                    qdtype=self.dtype,
                    scalar=True,
                )[0]
            if kv_scales.value_scale_cpu != 1.0:
                value = fp8_quantize(
                    value.float(),
                    scale=kv_scales.value_scale,
                    qdtype=self.dtype,
                    scalar=True,
                )[0]

        if ATTENTION in {"flashdecoding", "flashinfer"}:
            key = key.to(key_cache.dtype)
            value = value.to(value_cache.dtype)
            if key_cache.dtype in {torch.float8_e4m3fn, torch.float8_e5m2}:
                # Torch index_put does not support float8_{e5m2,e4m3fn} yet, so
                # put as raw data instead.
                key_cache = key_cache.view(torch.uint8)
                value_cache = value_cache.view(torch.uint8)
                key = key.view(torch.uint8)
                value = value.view(torch.uint8)
            shape = key_cache.shape
            key_cache.view(-1, shape[-2], shape[-1])[slots] = key
            value_cache.view(-1, shape[-2], shape[-1])[slots] = value
        else:
            paged_reshape_and_cache(key, value, key_cache, value_cache, slots)


def paged_reshape_and_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slots: torch.Tensor,
):
    if SYSTEM == "cuda":
        try:
            from vllm._C import cache_ops
        except Exception as e:
            raise ImportError(
                f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
            )
        cache_ops.reshape_and_cache(
            key, value, key_cache, value_cache, slots, "auto", 1.0
        )
    elif SYSTEM == "rocm":
        try:
            import vllm._custom_ops as ops
        except Exception as e:
            raise ImportError(
                f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
            )
        ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0)
    elif SYSTEM == "ipex":
        import intel_extension_for_pytorch as ipex

        ipex.llm.modules.PagedAttention.reshape_and_cache(
            key, value, key_cache, value_cache, slots
        )
    else:
        raise NotImplementedError(
            f"Cannot reshape and cache for paged attention, system '{SYSTEM}' not supported"
        )


def get_kv_scales(weights: Weights, prefix: str) -> KVScales:
    """Load KV cache scales."""

    key_scale = torch.tensor(1.0, dtype=torch.float32, device=weights.device)
    value_scale = key_scale
    if weights.has_tensor(f"{prefix}.k_scale") and weights.has_tensor(
        f"{prefix}.v_scale"
    ):
        key_scale = weights.get_tensor(f"{prefix}.k_scale", to_dtype=False).float()
        value_scale = weights.get_tensor(f"{prefix}.v_scale", to_dtype=False).float()
    elif weights.has_tensor(f"{prefix}.kv_scale"):
        # Fall back to the older, more coarse-grained scale when available.
        key_scale = weights.get_tensor(f"{prefix}.kv_scale").float()
        value_scale = key_scale

    return KVScales(key_scale=key_scale, value_scale=value_scale)
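

# Illustrative usage sketch (not part of the original module; the sizes, device,
# and prefix below are assumptions). Scales only take effect for a
# float8_e4m3fn cache with flashinfer on CUDA; otherwise `can_scale` logs once
# and the scales are ignored.
#
#   kv_scales = get_kv_scales(weights, prefix="model.layers.0.self_attn")
#   kv_cache = KVCache(
#       num_blocks=1024,
#       num_heads=8,
#       head_size=128,
#       dtype=torch.float8_e4m3fn,
#       device=torch.device("cuda"),
#   )
#   kv_cache.store(key=key, value=value, slots=slots, kv_scales=kv_scales)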