Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-20 14:22:08 +00:00)
* Add support for FP8 KV cache scales

  Since FP8 only has limited dynamic range, we can scale keys/values before storing them into the cache (and unscale them in attention). To avoid rescaling the cache as the absmax values change, good scales are usually determined per layer using calibration data and stored in the checkpoint.

  This change adds support for using key-value scales and loading them from checkpoints in the two most common formats:

  - Separate per-layer `k_scale` and `v_scale` scalars.
  - Per-layer `kv_scale` scalar (older format).

  Currently, scales are only used with a `float8_e4m3fn` cache.

  Besides adding support for key/value scales, the `fp8_quantize` function is also extended to support quantization with a kernel vendored from vLLM. This is slightly faster than the PyTorch implementation and also scales in FP32, potentially improving accuracy.

* Update FP8 KV cache test to use checkpoint with scales

* `can_scale`: check that the attention is flashinfer
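To make the scaling concrete, here is a minimal sketch (not the TGI implementation) of mapping keys/values into `float8_e4m3fn` with a calibrated per-layer scale and mapping them back before attention. The helper names and the example `k_scale` value are hypothetical; the only assumption is a PyTorch build with `float8_e4m3fn` support.

```python
import torch

# Largest representable magnitude in float8_e4m3fn (about 448).
FP8_MAX = torch.finfo(torch.float8_e4m3fn).max


def quantize_to_fp8_cache(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Scale in FP32 for accuracy, clamp to the FP8 range, then cast."""
    x_scaled = (x.to(torch.float32) / scale).clamp(-FP8_MAX, FP8_MAX)
    return x_scaled.to(torch.float8_e4m3fn)


def dequantize_from_fp8_cache(
    x_fp8: torch.Tensor, scale: torch.Tensor, dtype: torch.dtype = torch.float16
) -> torch.Tensor:
    """Undo the scaling when the cache is read back for attention."""
    return (x_fp8.to(torch.float32) * scale).to(dtype)


# Hypothetical per-layer scale as it would come from a checkpoint, either as
# separate `k_scale`/`v_scale` scalars or a single shared `kv_scale`.
key = torch.randn(4, 8, 64, dtype=torch.float16)
k_scale = torch.tensor(0.05)  # e.g. calibrated absmax / FP8_MAX
key_cached = quantize_to_fp8_cache(key, k_scale)
key_restored = dequantize_from_fp8_cache(key_cached, k_scale)
```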
89 lines
2.2 KiB
Python
import intel_extension_for_pytorch as ipex
import torch
from text_generation_server.layers.attention.kv_cache import KVCache, KVScales
from text_generation_server.models.flash_causal_lm import BLOCK_SIZE
from text_generation_server.layers.attention import Seqlen
from typing import Optional

SUPPORTS_WINDOWING = False

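# IPEX attention backend: `attention` wraps ipex.llm.functional.varlen_attention
# and `paged_attention` wraps the paged KV-cache kernel. Sliding-window
# attention is not supported here (SUPPORTS_WINDOWING = False).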
def attention(
    *,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: KVCache,
    kv_scales: KVScales,
    seqlen: Seqlen,
    block_tables: torch.Tensor,
    softmax_scale: float,
    window_size_left: int = -1,
    causal: bool = True,
    softcap: Optional[float] = None,
):
    if softcap is not None:
        raise NotImplementedError("softcap is not available in IPEX")

    out = torch.empty_like(query)

    # No need to check window_size_left here (windowing is not supported);
    # it is already validated ahead of time, at model load.
    ipex.llm.functional.varlen_attention(
        query.contiguous() if query.device.type == "xpu" else query,
        key.contiguous() if key.device.type == "xpu" else key,
        value.contiguous() if value.device.type == "xpu" else value,
        out,
        seqlen.cu_seqlen_q,
        seqlen.cu_seqlen_q,
        seqlen.max_q,
        seqlen.max_q,
        0.0,
        softmax_scale,
        False,
        causal,
        False,
        None,
    )

    return out

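# Note: `kv_scales` is accepted by both functions in this module but is not
# referenced by either; it appears to exist only to keep the signatures
# aligned with the other attention backends.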
def paged_attention(
    query: torch.Tensor,
    kv_cache: KVCache,
    kv_head_mapping: torch.Tensor,
    softmax_scale: float,
    block_tables: torch.Tensor,
    seqlen: Seqlen,
    max_s: int,
    *,
    kv_scales: KVScales,
    softcap: Optional[float] = None,
):
    if softcap is not None:
        raise NotImplementedError("softcap is not available in IPEX")

    out = torch.empty_like(query)
    input_lengths = seqlen.input_lengths + seqlen.cache_lengths
    ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
        out,
        query,
        kv_cache.key,
        kv_cache.value,
        kv_head_mapping,
        softmax_scale,
        block_tables,
        input_lengths,
        BLOCK_SIZE,
        max_s,
        None,
    )
    return out

__all__ = [
    "SUPPORTS_WINDOWING",
    "attention",
    "paged_attention",
]