fp8 kv cache

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Author: Wang, Yi A
Date:   2025-05-05 17:44:28 -07:00
Commit: 1cda91135e
Parent: 2007269fe7
6 changed files with 96 additions and 15 deletions

@@ -26,6 +26,11 @@ class Dtype(str, Enum):
     bloat16 = "bfloat16"


+class KVCacheDtype(str, Enum):
+    fp8_e4m3fn = "fp8_e4m3fn"
+    fp8_e5m2 = "fp8_e5m2"
+
+
 @app.command()
 def serve(
     model_id: str,
@@ -34,6 +39,7 @@ def serve(
     quantize: Optional[Quantization] = None,
     speculate: Optional[int] = None,
     dtype: Optional[Dtype] = None,
+    kv_cache_dtype: Optional[KVCacheDtype] = None,
     trust_remote_code: bool = False,
     uds_path: Path = "/tmp/text-generation-server",
     logger_level: str = "INFO",
@@ -93,7 +99,8 @@ def serve(
     # Downgrade enum into str for easier management later on
     quantize = None if quantize is None else quantize.value
     dtype = "bfloat16" if dtype is None else dtype.value
-    logger.info(f"quantize={quantize}")
+    kv_cache_dtype = None if kv_cache_dtype is None else kv_cache_dtype.value
+    logger.info(f"quantize={quantize} kv_cache_dtype={kv_cache_dtype}")
     if dtype is not None and quantize not in {
         None,
         "bitsandbytes",
@@ -175,6 +182,7 @@ def serve(
         quantize,
         speculate,
         dtype,
+        kv_cache_dtype,
         trust_remote_code,
         uds_path,
         max_input_tokens,
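The new kv_cache_dtype option goes through the same enum-to-string downgrade as quantize and dtype before it is handed to the server. A minimal standalone sketch of that step (names local to this example, not taken from the repository):

from enum import Enum
from typing import Optional


class KVCacheDtype(str, Enum):
    fp8_e4m3fn = "fp8_e4m3fn"
    fp8_e5m2 = "fp8_e5m2"


def downgrade(kv_cache_dtype: Optional[KVCacheDtype]) -> Optional[str]:
    # None stays None; an enum member is reduced to its plain string value,
    # mirroring the "downgrade enum into str" step in serve().
    return None if kv_cache_dtype is None else kv_cache_dtype.value


assert downgrade(None) is None
assert downgrade(KVCacheDtype.fp8_e4m3fn) == "fp8_e4m3fn"
assert downgrade(KVCacheDtype.fp8_e5m2) == "fp8_e5m2"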

@@ -11,11 +11,61 @@ import os

 SUPPORTS_WINDOWING = False


-def fetch_from_cache(cache, blocks):
-    if os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() == "true":
-        return cache[: blocks.size(0)]
-    else:
-        return cache.index_select(0, blocks)
+class FP8Matmul(torch.nn.Module):
+
+    def __init__(self, scale_other):
+        super().__init__()
+        self.scale_input = torch.tensor(1.0, dtype=torch.bfloat16, device="hpu")
+        self.scale_other = scale_other
+
+    def quant_input(self, x, scale):
+        return torch.ops.hpu.cast_to_fp8_v2(
+            x, scale, False, False, torch.float8_e4m3fn
+        )[0]
+
+    def matmul_fp8(
+        self, x, other, out_dtype, scale_input_inv=None, scale_other_inv=None
+    ):
+        return torch.ops.hpu.fp8_gemm_v2(
+            A=x,
+            trans_A=False,
+            B=other,
+            trans_B=False,
+            D=None,
+            out_dtype=out_dtype,
+            A_scale_inv=scale_input_inv,
+            B_scale_inv=scale_other_inv,
+            bias=None,
+            accumulate=False,
+        )
+
+    def forward(self, input, other):
+        qinput = self.quant_input(input, self.scale_input)
+        qother = self.quant_input(other, self.scale_other)
+        output = self.matmul_fp8(
+            qinput,
+            qother,
+            out_dtype=torch.bfloat16,
+            scale_input_inv=1.0 / self.scale_input,
+            scale_other_inv=1.0 / self.scale_other,
+        )
+        return output
+
+
+class FetchFromCache(torch.nn.Module):
+
+    def __init__(self, scale_inv):
+        super().__init__()
+        self.scale_inv = scale_inv
+
+    def forward(self, cache, blocks):
+        if os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() == "true":
+            out = cache[: blocks.size(0)]
+        else:
+            out = cache.index_select(0, blocks)
+        if out.dtype == torch.float8_e4m3fn:
+            out = torch.ops.hpu.cast_from_fp8(out, self.scale_inv, torch.bfloat16)
+        return out


 def attention(
@@ -67,6 +117,7 @@ def paged_attention(
     hpu_attention_meta: HPUPagedAttentionMetadata,
 ):
     batch_size, head_num, head_size = query.shape
+    fp8_kv = kv_cache.key.dtype == torch.float8_e4m3fn
     output = ops.flat_pa(
         query=query.view(batch_size, 1, head_num * head_size),
         key_cache=kv_cache.key,
@@ -76,12 +127,12 @@ def paged_attention(
         block_bias=hpu_attention_meta.attn_bias,
         block_groups=hpu_attention_meta.block_groups,
         scale=softmax_scale,
-        matmul_qk_op=Matmul(),
-        matmul_av_op=Matmul(),
+        matmul_qk_op=FP8Matmul(kv_scales.key_scale) if fp8_kv else Matmul(),
+        matmul_av_op=FP8Matmul(kv_scales.value_scale) if fp8_kv else Matmul(),
         batch2block_matmul_op=Matmul(),
         block2batch_matmul_op=Matmul(),
-        keys_fetch_func=fetch_from_cache,
-        values_fetch_func=fetch_from_cache,
+        keys_fetch_func=FetchFromCache(1.0 / kv_scales.key_scale_cpu),
+        values_fetch_func=FetchFromCache(1.0 / kv_scales.value_scale_cpu),
     )
     # Reshape the output tensor.
     return output.view(batch_size, head_num, head_size)
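When the cache dtype is torch.float8_e4m3fn, FetchFromCache dequantizes fetched blocks back to bfloat16 with the stored inverse scale, and FP8Matmul quantizes its inputs before calling the FP8 GEMM. The HPU ops (cast_to_fp8_v2, cast_from_fp8, fp8_gemm_v2) are device-specific, so the sketch below only emulates the scale-and-cast round trip with plain PyTorch, assuming the scale is applied multiplicatively at quantization and its inverse at dequantization (function names are illustrative, not from the code; torch.float8_e4m3fn needs a recent PyTorch):

import torch


def quantize_fp8(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Assumed behaviour: apply the scale, then narrow to float8_e4m3fn.
    return (x * scale).to(torch.float8_e4m3fn)


def dequantize_fp8(q: torch.Tensor, scale_inv: torch.Tensor) -> torch.Tensor:
    # Assumed behaviour: widen back to bfloat16, then undo the scale.
    return q.to(torch.bfloat16) * scale_inv


x = torch.randn(4, 8, dtype=torch.bfloat16)
scale = torch.tensor(2.0, dtype=torch.bfloat16)
roundtrip = dequantize_fp8(quantize_fp8(x, scale), 1.0 / scale)
print((roundtrip - x).abs().max())  # small, bounded by fp8 e4m3 precision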

@@ -50,6 +50,8 @@ class KVCache:
     ):
         """Construct the key-value cache for a layer."""
         ## TODO FP8 kv cache support
+        if dtype is torch.float8_e5m2:
+            raise ValueError("torch.float8_e5m2 is not supported in hpu. ")

         self.kv_cache = (
             torch.zeros(
@@ -101,8 +103,8 @@ class KVCache:
             key_cache,
             value_cache,
             slots,
-            kv_scales.key_scale_cpu,
-            kv_scales.value_scale_cpu,
+            kv_scales.key_scale,
+            kv_scales.value_scale,
         )
@@ -112,11 +114,18 @@ def paged_reshape_and_cache(
     key_cache: torch.Tensor,
     value_cache: torch.Tensor,
     slots: torch.Tensor,
-    k_scale: float = 1.0,
-    v_scale: float = 1.0,
+    k_scale: torch.Tensor,
+    v_scale: torch.Tensor,
 ):
     block_idx = slots // BLOCK_SIZE
     block_offset = slots % BLOCK_SIZE
+    if key_cache.dtype == torch.float8_e4m3fn:
+        key = torch.ops.hpu.cast_to_fp8_v2(
+            key, k_scale, False, False, torch.float8_e4m3fn
+        )[0]
+        value = torch.ops.hpu.cast_to_fp8_v2(
+            value, v_scale, False, False, torch.float8_e4m3fn
+        )[0]
     cache_ops.insert_or_update_cache(key, key_cache, block_idx, block_offset)
     cache_ops.insert_or_update_cache(value, value_cache, block_idx, block_offset)
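paged_reshape_and_cache first converts flat slot indices into (block, offset) pairs and, when the cache is float8_e4m3fn, quantizes key/value with the per-tensor scales before insertion. The slot arithmetic is easy to check in isolation; a small sketch, assuming an illustrative BLOCK_SIZE of 16 (the real constant is defined by the backend):

import torch

BLOCK_SIZE = 16  # illustrative value, not taken from the diff

slots = torch.tensor([0, 15, 16, 33])
block_idx = slots // BLOCK_SIZE    # tensor([0, 0, 1, 2])
block_offset = slots % BLOCK_SIZE  # tensor([0, 15, 0, 1])
print(block_idx, block_offset)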

@@ -324,6 +324,7 @@ def get_model(
     quantize: Optional[str],
     speculate: Optional[int],
     dtype: Optional[torch.dtype],
+    kv_cache_dtype: Optional[str],
     trust_remote_code: bool,
     max_input_tokens: int,
 ) -> Model:
@@ -449,7 +450,12 @@ def get_model(
     model_type = config_dict["model_type"]

-    kv_cache_dtype = dtype
+    if kv_cache_dtype == "fp8_e4m3fn":
+        kv_cache_dtype = torch.float8_e4m3fn
+    elif kv_cache_dtype == "fp8_e5m2":
+        kv_cache_dtype = torch.float8_e5m2
+    else:
+        kv_cache_dtype = dtype

     if FLASH_ATTENTION:
         if model_type == DEEPSEEK_V2:
@@ -890,6 +896,7 @@ def get_model_with_lora_adapters(
     quantize: Optional[str],
     speculate: Optional[int],
     dtype: Optional[torch.dtype],
+    kv_cache_dtype: Optional[str],
     trust_remote_code: bool,
     max_input_tokens: int,
     adapter_to_index: Dict[str, int],
@@ -903,6 +910,7 @@ def get_model_with_lora_adapters(
         quantize,
         speculate,
         dtype,
+        kv_cache_dtype,
         trust_remote_code,
         max_input_tokens,
     )
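The resolution of the kv_cache_dtype string into a torch dtype can be restated as a standalone helper; this is a sketch of the mapping in the hunk above (the helper name is invented for illustration), with the fallback to the model dtype when no FP8 variant is requested. Note that "fp8_e5m2" still resolves here but is rejected later by the HPU KVCache constructor.

import torch
from typing import Optional


def resolve_kv_cache_dtype(
    kv_cache_dtype: Optional[str], dtype: torch.dtype
) -> torch.dtype:
    # Same mapping as above: an explicit FP8 string selects the matching
    # torch dtype, anything else falls back to the model dtype.
    if kv_cache_dtype == "fp8_e4m3fn":
        return torch.float8_e4m3fn
    if kv_cache_dtype == "fp8_e5m2":
        return torch.float8_e5m2
    return dtype


assert resolve_kv_cache_dtype(None, torch.bfloat16) is torch.bfloat16
assert resolve_kv_cache_dtype("fp8_e4m3fn", torch.bfloat16) is torch.float8_e4m3fn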

@@ -1438,6 +1438,7 @@ class FlashCausalLM(Model):
         self.kv_cache = []
         self.kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype
         self.bucketing_ctx = None
+        htorch.core.hpu_set_env()
         if htorch.utils.internal.is_lazy():
             htorch.hpu.wrap_in_hpu_graph(model, disable_tensor_cache=True)
         environment.set_model_config(self.config)

@@ -224,6 +224,7 @@ def serve(
     quantize: Optional[str],
     speculate: Optional[int],
     dtype: Optional[str],
+    kv_cache_dtype: Optional[str],
     trust_remote_code: bool,
     uds_path: Path,
     max_input_tokens: int,
@@ -236,6 +237,7 @@ def serve(
         quantize: Optional[str] = None,
         speculate: Optional[int] = None,
         dtype: Optional[str] = None,
+        kv_cache_dtype: Optional[str] = None,
         trust_remote_code: bool = False,
     ):
         if not is_driver_compatible():
@@ -279,6 +281,7 @@ def serve(
             quantize,
             speculate,
             data_type,
+            kv_cache_dtype,
             trust_remote_code,
             max_input_tokens,
             adapter_to_index,
@@ -326,6 +329,7 @@ def serve(
             quantize,
             speculate,
             dtype,
+            kv_cache_dtype,
             trust_remote_code,
         )
     )