from typing import Tuple
from dataclasses import dataclass, field

from loguru import logger
import torch

from text_generation_server.layers.fp8 import fp8_quantize
from text_generation_server.models.globals import ATTENTION, BLOCK_SIZE
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.kernels import load_kernel
from text_generation_server.utils.log import log_once
from text_generation_server.utils.weights import Weights

if SYSTEM == "cuda":
    try:
        paged_attention = load_kernel(
            module="paged_attention", repo_id="kernels-community/paged-attention"
        )
    except Exception as e:
        raise ImportError(
            f"Could not import attention kernels. Make sure your installation is correct. Complete error: {e}"
        )
else:
    paged_attention = None


@dataclass
class KVScales:
    """
    Key-value scales for FP8 KV cache.

    This data class stores key and value scales both as a GPU tensor and
    as a CPU float. This inconvenience is necessary because some functions
    (e.g. scaling kernels) take scales as a GPU tensor, whereas others
    (e.g. flashinfer) take scales as a CPU scalar.
    """

    key_scale: torch.Tensor
    value_scale: torch.Tensor
    key_scale_cpu: float = field(init=False)
    value_scale_cpu: float = field(init=False)

    def __post_init__(self):
        if self.key_scale.numel() != 1 or self.value_scale.numel() != 1:
            raise ValueError("Key and value scales must be scalar tensors.")

        self.key_scale_cpu = self.key_scale.item()
        self.value_scale_cpu = self.value_scale.item()


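# A minimal illustrative sketch (kept as a comment so it is not executed at
# import time): constructing KVScales from scalar tensors; the CPU copies are
# derived automatically in __post_init__. The scale value and device below are
# assumptions for the example only.
#
#     scales = KVScales(
#         key_scale=torch.tensor(1.0, device="cuda"),
#         value_scale=torch.tensor(1.0, device="cuda"),
#     )
#     assert scales.key_scale_cpu == 1.0 and scales.value_scale_cpu == 1.0

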
class KVCache:
    """
    Key-value cache for attention layers.
    """

    kv_cache: Tuple[torch.Tensor, torch.Tensor]

    def __init__(
        self,
        *,
        num_blocks: int,
        num_heads: int,
        head_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ):
        """Construct the key-value cache for a layer."""

        if dtype in {torch.float8_e5m2, torch.float8_e4m3fn}:
            if not (
                (ATTENTION == "flashinfer" and SYSTEM == "cuda")
                or (ATTENTION == "paged" and SYSTEM in ("cuda", "rocm", "ipex"))
                or (ATTENTION == "flashdecoding-ipex")
            ):
                raise ValueError(
                    "FP8 KV cache is currently only supported for flashinfer on CUDA, paged attention on CUDA, ROCm and Intel IPEX, and flashdecoding on Intel IPEX"
                )
            if SYSTEM == "rocm" and dtype == torch.float8_e5m2:
                raise ValueError(
                    "float8_e5m2 FP8 KV cache is not supported on AMD ROCm"
                )
            if device.type == "cpu" and dtype == torch.float8_e4m3fn:
                raise ValueError(
                    "float8_e4m3fn FP8 KV cache is not supported on Intel IPEX CPU"
                )

        element_size = torch.tensor([], dtype=dtype).element_size()
        if SYSTEM == "ipex" and device.type == "xpu":
            x = 1
        else:
            x = BLOCK_SIZE // element_size

        if ATTENTION in {"flashdecoding", "flashinfer"} or (
            ATTENTION == "flashdecoding-ipex" and device.type == "xpu"
        ):
            # flashdecoding/flashinfer layout: [num_blocks, block_size, num_heads, head_size]
            self.kv_cache = (
                torch.empty(
                    (num_blocks, BLOCK_SIZE, num_heads, head_size),
                    dtype=dtype,
                    device=device,
                ),
                torch.empty(
                    (num_blocks, BLOCK_SIZE, num_heads, head_size),
                    dtype=dtype,
                    device=device,
                ),
            )
        elif SYSTEM == "ipex" and device == torch.device("cpu"):
            # The ipex CPU flashdecoding kernel and paged attention kernel share the same layout.
            self.kv_cache = (
                torch.empty(
                    (num_blocks, num_heads, BLOCK_SIZE, head_size),
                    dtype=dtype,
                    device=device,
                ),
                torch.empty(
                    (num_blocks, num_heads, BLOCK_SIZE, head_size),
                    dtype=dtype,
                    device=device,
                ),
            )
        else:
            # Paged attention layout: the key cache splits head_size into
            # chunks of x elements, the value cache keeps head_size contiguous.
            self.kv_cache = (
                torch.zeros(
                    (num_blocks, num_heads, head_size // x, BLOCK_SIZE, x),
                    dtype=dtype,
                    device=device,
                ),
                torch.zeros(
                    (num_blocks, num_heads, head_size, BLOCK_SIZE),
                    dtype=dtype,
                    device=device,
                ),
            )

    def can_scale(self, kv_scales: KVScales) -> bool:
        """Check if the cache can be scaled by the given scales."""
        if kv_scales.key_scale_cpu == 1.0 and kv_scales.value_scale_cpu == 1.0:
            return False
        elif self.dtype == torch.float8_e4m3fn and (
            (ATTENTION in ("paged", "flashinfer") and SYSTEM == "cuda")
            or (ATTENTION == "paged" and SYSTEM in ["rocm", "ipex"])
            or (ATTENTION == "flashdecoding-ipex")
        ):
            log_once(logger.info, "Using FP8 KV cache scales")
            return True
        else:
            # We have scales, but not the correct FP8 cache type, so log once.
            log_once(
                logger.info,
                "Ignoring FP8 KV cache scales, supported only for float8_e4m3fn KV cache with paged attention or flashinfer on CUDA, paged attention on ROCm/IPEX, and flashdecoding on IPEX",
            )
            return False

    @property
    def dtype(self):
        """Get the data type of the cache."""
        return self.kv_cache[0].dtype

    @property
    def key(self):
        """Get the key cache."""
        return self.kv_cache[0]

    @property
    def value(self):
        """Get the value cache."""
        return self.kv_cache[1]

    def store(
        self,
        *,
        key: torch.Tensor,
        value: torch.Tensor,
        slots: torch.Tensor,
        kv_scales: KVScales,
    ):
        """Store the key and value at the given slots."""
        key_cache = self.kv_cache[0]
        value_cache = self.kv_cache[1]

        if self.can_scale(kv_scales) and SYSTEM == "cuda":
            if kv_scales.key_scale_cpu != 1.0:
                key = fp8_quantize(
                    key.float(),
                    scale=kv_scales.key_scale,
                    qdtype=self.dtype,
                    scalar=True,
                )[0]
            if kv_scales.value_scale_cpu != 1.0:
                value = fp8_quantize(
                    value.float(),
                    scale=kv_scales.value_scale,
                    qdtype=self.dtype,
                    scalar=True,
                )[0]

        if ATTENTION in {"flashdecoding", "flashinfer"}:
            key = key.to(key_cache.dtype)
            value = value.to(value_cache.dtype)
            if key_cache.dtype in {torch.float8_e4m3fn, torch.float8_e5m2}:
                # Torch index_put does not support float8_{e5m2,e4m3fn} yet, so
                # put as raw data instead.
                key_cache = key_cache.view(torch.uint8)
                value_cache = value_cache.view(torch.uint8)
                key = key.view(torch.uint8)
                value = value.view(torch.uint8)
            shape = key_cache.shape
            key_cache.view(-1, shape[-2], shape[-1])[slots] = key
            value_cache.view(-1, shape[-2], shape[-1])[slots] = value
        elif ATTENTION == "flashdecoding-ipex" and key.device.type == "xpu":
            import intel_extension_for_pytorch as ipex

            kv_cache_dtype = "auto"
            if key_cache.dtype == torch.float8_e5m2:
                kv_cache_dtype = "fp8_e5m2"
            if key_cache.dtype == torch.float8_e4m3fn:
                kv_cache_dtype = "fp8_e4m3"
            ipex.llm.modules.PagedAttention.reshape_and_cache_flash(
                key,
                value,
                key_cache,
                value_cache,
                slots,
                kv_cache_dtype=kv_cache_dtype,
                k_scale=kv_scales.key_scale_cpu,
                v_scale=kv_scales.value_scale_cpu,
            )
        else:
            paged_reshape_and_cache(
                key,
                value,
                key_cache,
                value_cache,
                slots,
                kv_scales.key_scale_cpu,
                kv_scales.value_scale_cpu,
            )


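# A minimal illustrative sketch (kept as a comment so it is not executed at
# import time): allocating a per-layer cache and storing projected keys and
# values into it. The shapes, dtype, device and slot assignment below are
# assumptions for the example; in the server they come from the model config
# and the batch being processed.
#
#     cache = KVCache(
#         num_blocks=1024,
#         num_heads=8,
#         head_size=128,
#         dtype=torch.float16,
#         device=torch.device("cuda"),
#     )
#     key = torch.randn(4, 8, 128, dtype=torch.float16, device="cuda")
#     value = torch.randn(4, 8, 128, dtype=torch.float16, device="cuda")
#     slots = torch.arange(4, dtype=torch.int64, device="cuda")
#     # `scales` as in the KVScales sketch above.
#     cache.store(key=key, value=value, slots=slots, kv_scales=scales)

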
def paged_reshape_and_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slots: torch.Tensor,
    k_scale: float = 1.0,
    v_scale: float = 1.0,
):
    if SYSTEM == "cuda":
        kv_cache_dtype = "auto"
        if key_cache.dtype == torch.float8_e4m3fn:
            kv_cache_dtype = "fp8"

        paged_attention.reshape_and_cache(
            key,
            value,
            key_cache,
            value_cache,
            slots,
            kv_cache_dtype,
            torch.tensor(k_scale),
            torch.tensor(v_scale),
        )
    elif SYSTEM == "rocm":
        try:
            import vllm._custom_ops as ops
        except Exception as e:
            raise ImportError(
                f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
            )

        kv_cache_dtype = "auto"
        if key_cache.dtype == torch.float8_e4m3fn:
            key_cache = key_cache.view(torch.uint8)
            value_cache = value_cache.view(torch.uint8)
            kv_cache_dtype = "fp8"

        ops.reshape_and_cache(
            key, value, key_cache, value_cache, slots, kv_cache_dtype, k_scale, v_scale
        )
    elif SYSTEM == "ipex":
        import intel_extension_for_pytorch as ipex

        kv_cache_dtype = "auto"
        if key_cache.dtype == torch.float8_e5m2:
            kv_cache_dtype = "fp8_e5m2"
        if key_cache.dtype == torch.float8_e4m3fn:
            kv_cache_dtype = "fp8_e4m3"
        ipex.llm.modules.PagedAttention.reshape_and_cache(
            key,
            value,
            key_cache,
            value_cache,
            slots,
            kv_cache_dtype=kv_cache_dtype,
            k_scale=k_scale,
            v_scale=v_scale,
        )
    else:
        raise NotImplementedError(
            f"Cannot reshape and cache for paged attention, system '{SYSTEM}' not supported"
        )


def get_kv_scales(weights: Weights, prefix: str) -> KVScales:
    """Load KV cache scales."""
    key_scale = torch.tensor(1.0, dtype=torch.float32, device=weights.device)
    value_scale = key_scale
    if weights.has_tensor(f"{prefix}.k_scale") and weights.has_tensor(
        f"{prefix}.v_scale"
    ):
        key_scale = weights.get_tensor(f"{prefix}.k_scale", to_dtype=False).float()
        value_scale = weights.get_tensor(f"{prefix}.v_scale", to_dtype=False).float()
    elif weights.has_tensor(f"{prefix}.kv_scale"):
        # Fall back to the older, more coarse-grained scale when available.
        key_scale = weights.get_tensor(f"{prefix}.kv_scale").float()
        value_scale = key_scale

    return KVScales(key_scale=key_scale, value_scale=value_scale)
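

# A minimal illustrative sketch (kept as a comment so it is not executed at
# import time): scales are looked up per attention prefix, preferring
# `<prefix>.k_scale`/`<prefix>.v_scale` and falling back to the older
# `<prefix>.kv_scale`; checkpoints without any of these get unit scales,
# which KVCache.can_scale treats as "no scaling". The prefix is an assumption
# for the example, and `weights` stands for an already-constructed Weights.
#
#     scales = get_kv_scales(weights, prefix="model.layers.0.self_attn")
#     # scales.key_scale_cpu == scales.value_scale_cpu == 1.0 unless the
#     # checkpoint provides k_scale/v_scale (or kv_scale) tensors.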