mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-09 03:14:53 +00:00
fp8 kv cache
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
parent
2007269fe7
commit
1cda91135e
@ -26,6 +26,11 @@ class Dtype(str, Enum):
|
||||
bloat16 = "bfloat16"
|
||||
|
||||
|
||||
class KVCacheDtype(str, Enum):
|
||||
fp8_e4m3fn = "fp8_e4m3fn"
|
||||
fp8_e5m2 = "fp8_e5m2"
|
||||
|
||||
|
||||
@app.command()
|
||||
def serve(
|
||||
model_id: str,
|
||||
@ -34,6 +39,7 @@ def serve(
|
||||
quantize: Optional[Quantization] = None,
|
||||
speculate: Optional[int] = None,
|
||||
dtype: Optional[Dtype] = None,
|
||||
kv_cache_dtype: Optional[KVCacheDtype] = None,
|
||||
trust_remote_code: bool = False,
|
||||
uds_path: Path = "/tmp/text-generation-server",
|
||||
logger_level: str = "INFO",
|
||||
@ -93,7 +99,8 @@ def serve(
|
||||
# Downgrade enum into str for easier management later on
|
||||
quantize = None if quantize is None else quantize.value
|
||||
dtype = "bfloat16" if dtype is None else dtype.value
|
||||
logger.info(f"quantize={quantize}")
|
||||
kv_cache_dtype = None if kv_cache_dtype is None else kv_cache_dtype.value
|
||||
logger.info(f"quantize={quantize} kv_cache_dtype={kv_cache_dtype}")
|
||||
if dtype is not None and quantize not in {
|
||||
None,
|
||||
"bitsandbytes",
|
||||
@ -175,6 +182,7 @@ def serve(
|
||||
quantize,
|
||||
speculate,
|
||||
dtype,
|
||||
kv_cache_dtype,
|
||||
trust_remote_code,
|
||||
uds_path,
|
||||
max_input_tokens,
|
||||
|
@ -11,11 +11,61 @@ import os
|
||||
SUPPORTS_WINDOWING = False
|
||||
|
||||
|
||||
def fetch_from_cache(cache, blocks):
|
||||
if os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() == "true":
|
||||
return cache[: blocks.size(0)]
|
||||
else:
|
||||
return cache.index_select(0, blocks)
|
||||
class FP8Matmul(torch.nn.Module):
|
||||
|
||||
def __init__(self, scale_other):
|
||||
super().__init__()
|
||||
self.scale_input = torch.tensor(1.0, dtype=torch.bfloat16, device="hpu")
|
||||
self.scale_other = scale_other
|
||||
|
||||
def quant_input(self, x, scale):
|
||||
return torch.ops.hpu.cast_to_fp8_v2(
|
||||
x, scale, False, False, torch.float8_e4m3fn
|
||||
)[0]
|
||||
|
||||
def matmul_fp8(
|
||||
self, x, other, out_dtype, scale_input_inv=None, scale_other_inv=None
|
||||
):
|
||||
return torch.ops.hpu.fp8_gemm_v2(
|
||||
A=x,
|
||||
trans_A=False,
|
||||
B=other,
|
||||
trans_B=False,
|
||||
D=None,
|
||||
out_dtype=out_dtype,
|
||||
A_scale_inv=scale_input_inv,
|
||||
B_scale_inv=scale_other_inv,
|
||||
bias=None,
|
||||
accumulate=False,
|
||||
)
|
||||
|
||||
def forward(self, input, other):
|
||||
qinput = self.quant_input(input, self.scale_input)
|
||||
qother = self.quant_input(other, self.scale_other)
|
||||
output = self.matmul_fp8(
|
||||
qinput,
|
||||
qother,
|
||||
out_dtype=torch.bfloat16,
|
||||
scale_input_inv=1.0 / self.scale_input,
|
||||
scale_other_inv=1.0 / self.scale_other,
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
class FetchFromCache(torch.nn.Module):
|
||||
|
||||
def __init__(self, scale_inv):
|
||||
super().__init__()
|
||||
self.scale_inv = scale_inv
|
||||
|
||||
def forward(self, cache, blocks):
|
||||
if os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() == "true":
|
||||
out = cache[: blocks.size(0)]
|
||||
else:
|
||||
out = cache.index_select(0, blocks)
|
||||
if out.dtype == torch.float8_e4m3fn:
|
||||
out = torch.ops.hpu.cast_from_fp8(out, self.scale_inv, torch.bfloat16)
|
||||
return out
|
||||
|
||||
|
||||
def attention(
|
||||
@ -67,6 +117,7 @@ def paged_attention(
|
||||
hpu_attention_meta: HPUPagedAttentionMetadata,
|
||||
):
|
||||
batch_size, head_num, head_size = query.shape
|
||||
fp8_kv = kv_cache.key.dtype == torch.float8_e4m3fn
|
||||
output = ops.flat_pa(
|
||||
query=query.view(batch_size, 1, head_num * head_size),
|
||||
key_cache=kv_cache.key,
|
||||
@ -76,12 +127,12 @@ def paged_attention(
|
||||
block_bias=hpu_attention_meta.attn_bias,
|
||||
block_groups=hpu_attention_meta.block_groups,
|
||||
scale=softmax_scale,
|
||||
matmul_qk_op=Matmul(),
|
||||
matmul_av_op=Matmul(),
|
||||
matmul_qk_op=FP8Matmul(kv_scales.key_scale) if fp8_kv else Matmul(),
|
||||
matmul_av_op=FP8Matmul(kv_scales.value_scale) if fp8_kv else Matmul(),
|
||||
batch2block_matmul_op=Matmul(),
|
||||
block2batch_matmul_op=Matmul(),
|
||||
keys_fetch_func=fetch_from_cache,
|
||||
values_fetch_func=fetch_from_cache,
|
||||
keys_fetch_func=FetchFromCache(1.0 / kv_scales.key_scale_cpu),
|
||||
values_fetch_func=FetchFromCache(1.0 / kv_scales.value_scale_cpu),
|
||||
)
|
||||
# Reshape the output tensor.
|
||||
return output.view(batch_size, head_num, head_size)
|
||||
|
@ -50,6 +50,8 @@ class KVCache:
|
||||
):
|
||||
"""Construct the key-value cache for a layer."""
|
||||
## TODO FP8 kv cache support
|
||||
if dtype is torch.float8_e5m2:
|
||||
raise ValueError("torch.float8_e5m2 is not supported in hpu. ")
|
||||
|
||||
self.kv_cache = (
|
||||
torch.zeros(
|
||||
@ -101,8 +103,8 @@ class KVCache:
|
||||
key_cache,
|
||||
value_cache,
|
||||
slots,
|
||||
kv_scales.key_scale_cpu,
|
||||
kv_scales.value_scale_cpu,
|
||||
kv_scales.key_scale,
|
||||
kv_scales.value_scale,
|
||||
)
|
||||
|
||||
|
||||
@ -112,11 +114,18 @@ def paged_reshape_and_cache(
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
slots: torch.Tensor,
|
||||
k_scale: float = 1.0,
|
||||
v_scale: float = 1.0,
|
||||
k_scale: torch.Tensor,
|
||||
v_scale: torch.Tensor,
|
||||
):
|
||||
block_idx = slots // BLOCK_SIZE
|
||||
block_offset = slots % BLOCK_SIZE
|
||||
if key_cache.dtype == torch.float8_e4m3fn:
|
||||
key = torch.ops.hpu.cast_to_fp8_v2(
|
||||
key, k_scale, False, False, torch.float8_e4m3fn
|
||||
)[0]
|
||||
value = torch.ops.hpu.cast_to_fp8_v2(
|
||||
value, v_scale, False, False, torch.float8_e4m3fn
|
||||
)[0]
|
||||
cache_ops.insert_or_update_cache(key, key_cache, block_idx, block_offset)
|
||||
cache_ops.insert_or_update_cache(value, value_cache, block_idx, block_offset)
|
||||
|
||||
|
@ -324,6 +324,7 @@ def get_model(
|
||||
quantize: Optional[str],
|
||||
speculate: Optional[int],
|
||||
dtype: Optional[torch.dtype],
|
||||
kv_cache_dtype: Optional[str],
|
||||
trust_remote_code: bool,
|
||||
max_input_tokens: int,
|
||||
) -> Model:
|
||||
@ -449,7 +450,12 @@ def get_model(
|
||||
|
||||
model_type = config_dict["model_type"]
|
||||
|
||||
kv_cache_dtype = dtype
|
||||
if kv_cache_dtype == "fp8_e4m3fn":
|
||||
kv_cache_dtype = torch.float8_e4m3fn
|
||||
elif kv_cache_dtype == "fp8_e5m2":
|
||||
kv_cache_dtype = torch.float8_e5m2
|
||||
else:
|
||||
kv_cache_dtype = dtype
|
||||
|
||||
if FLASH_ATTENTION:
|
||||
if model_type == DEEPSEEK_V2:
|
||||
@ -890,6 +896,7 @@ def get_model_with_lora_adapters(
|
||||
quantize: Optional[str],
|
||||
speculate: Optional[int],
|
||||
dtype: Optional[torch.dtype],
|
||||
kv_cache_dtype: Optional[str],
|
||||
trust_remote_code: bool,
|
||||
max_input_tokens: int,
|
||||
adapter_to_index: Dict[str, int],
|
||||
@ -903,6 +910,7 @@ def get_model_with_lora_adapters(
|
||||
quantize,
|
||||
speculate,
|
||||
dtype,
|
||||
kv_cache_dtype,
|
||||
trust_remote_code,
|
||||
max_input_tokens,
|
||||
)
|
||||
|
@ -1438,6 +1438,7 @@ class FlashCausalLM(Model):
|
||||
self.kv_cache = []
|
||||
self.kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype
|
||||
self.bucketing_ctx = None
|
||||
htorch.core.hpu_set_env()
|
||||
if htorch.utils.internal.is_lazy():
|
||||
htorch.hpu.wrap_in_hpu_graph(model, disable_tensor_cache=True)
|
||||
environment.set_model_config(self.config)
|
||||
|
@ -224,6 +224,7 @@ def serve(
|
||||
quantize: Optional[str],
|
||||
speculate: Optional[int],
|
||||
dtype: Optional[str],
|
||||
kv_cache_dtype: Optional[str],
|
||||
trust_remote_code: bool,
|
||||
uds_path: Path,
|
||||
max_input_tokens: int,
|
||||
@ -236,6 +237,7 @@ def serve(
|
||||
quantize: Optional[str] = None,
|
||||
speculate: Optional[int] = None,
|
||||
dtype: Optional[str] = None,
|
||||
kv_cache_dtype: Optional[str] = None,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
if not is_driver_compatible():
|
||||
@ -279,6 +281,7 @@ def serve(
|
||||
quantize,
|
||||
speculate,
|
||||
data_type,
|
||||
kv_cache_dtype,
|
||||
trust_remote_code,
|
||||
max_input_tokens,
|
||||
adapter_to_index,
|
||||
@ -326,6 +329,7 @@ def serve(
|
||||
quantize,
|
||||
speculate,
|
||||
dtype,
|
||||
kv_cache_dtype,
|
||||
trust_remote_code,
|
||||
)
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user