Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-09 19:34:53 +00:00)
Commit: 1cda91135e
Parent: 2007269fe7

fp8 kv cache

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
@@ -26,6 +26,11 @@ class Dtype(str, Enum):
     bloat16 = "bfloat16"
 
 
+class KVCacheDtype(str, Enum):
+    fp8_e4m3fn = "fp8_e4m3fn"
+    fp8_e5m2 = "fp8_e5m2"
+
+
 @app.command()
 def serve(
     model_id: str,
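
[Note] The new KVCacheDtype mirrors the existing Dtype pattern: a str-backed Enum whose .value is what gets passed downstream. A minimal, self-contained sketch of that resolution step (illustrative only; names mirror the diff rather than copying the module):

from enum import Enum
from typing import Optional

class KVCacheDtype(str, Enum):
    fp8_e4m3fn = "fp8_e4m3fn"
    fp8_e5m2 = "fp8_e5m2"

def downgrade(choice: Optional[KVCacheDtype]) -> Optional[str]:
    # Same "downgrade enum into str" step serve() performs below.
    return None if choice is None else choice.value

assert downgrade(KVCacheDtype.fp8_e4m3fn) == "fp8_e4m3fn"
assert downgrade(None) is None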
@@ -34,6 +39,7 @@ def serve(
     quantize: Optional[Quantization] = None,
     speculate: Optional[int] = None,
     dtype: Optional[Dtype] = None,
+    kv_cache_dtype: Optional[KVCacheDtype] = None,
     trust_remote_code: bool = False,
     uds_path: Path = "/tmp/text-generation-server",
     logger_level: str = "INFO",
@@ -93,7 +99,8 @@ def serve(
     # Downgrade enum into str for easier management later on
     quantize = None if quantize is None else quantize.value
     dtype = "bfloat16" if dtype is None else dtype.value
-    logger.info(f"quantize={quantize}")
+    kv_cache_dtype = None if kv_cache_dtype is None else kv_cache_dtype.value
+    logger.info(f"quantize={quantize} kv_cache_dtype={kv_cache_dtype}")
     if dtype is not None and quantize not in {
         None,
         "bitsandbytes",
@@ -175,6 +182,7 @@ def serve(
         quantize,
         speculate,
         dtype,
+        kv_cache_dtype,
         trust_remote_code,
         uds_path,
         max_input_tokens,
@@ -11,11 +11,61 @@ import os
 SUPPORTS_WINDOWING = False
 
 
-def fetch_from_cache(cache, blocks):
-    if os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() == "true":
-        return cache[: blocks.size(0)]
-    else:
-        return cache.index_select(0, blocks)
+class FP8Matmul(torch.nn.Module):
+
+    def __init__(self, scale_other):
+        super().__init__()
+        self.scale_input = torch.tensor(1.0, dtype=torch.bfloat16, device="hpu")
+        self.scale_other = scale_other
+
+    def quant_input(self, x, scale):
+        return torch.ops.hpu.cast_to_fp8_v2(
+            x, scale, False, False, torch.float8_e4m3fn
+        )[0]
+
+    def matmul_fp8(
+        self, x, other, out_dtype, scale_input_inv=None, scale_other_inv=None
+    ):
+        return torch.ops.hpu.fp8_gemm_v2(
+            A=x,
+            trans_A=False,
+            B=other,
+            trans_B=False,
+            D=None,
+            out_dtype=out_dtype,
+            A_scale_inv=scale_input_inv,
+            B_scale_inv=scale_other_inv,
+            bias=None,
+            accumulate=False,
+        )
+
+    def forward(self, input, other):
+        qinput = self.quant_input(input, self.scale_input)
+        qother = self.quant_input(other, self.scale_other)
+        output = self.matmul_fp8(
+            qinput,
+            qother,
+            out_dtype=torch.bfloat16,
+            scale_input_inv=1.0 / self.scale_input,
+            scale_other_inv=1.0 / self.scale_other,
+        )
+        return output
+
+
+class FetchFromCache(torch.nn.Module):
+
+    def __init__(self, scale_inv):
+        super().__init__()
+        self.scale_inv = scale_inv
+
+    def forward(self, cache, blocks):
+        if os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() == "true":
+            out = cache[: blocks.size(0)]
+        else:
+            out = cache.index_select(0, blocks)
+        if out.dtype == torch.float8_e4m3fn:
+            out = torch.ops.hpu.cast_from_fp8(out, self.scale_inv, torch.bfloat16)
+        return out
 
 
 def attention(
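
[Note] FP8Matmul and FetchFromCache implement a scaled FP8 round trip: tensors are scaled and narrowed to float8_e4m3fn on the way in, then widened back to bfloat16 with the inverse scale on the way out. A hedged, CPU-runnable sketch of that round trip using plain PyTorch casts (the real path uses torch.ops.hpu.cast_to_fp8_v2 / cast_from_fp8 on Gaudi):

import torch

x = torch.randn(4, 8, dtype=torch.bfloat16)
scale = torch.tensor(1.0, dtype=torch.bfloat16)

# Quantize: apply the scale, then narrow to FP8 (e4m3).
q = (x * scale).to(torch.float8_e4m3fn)

# Dequantize: widen back to bf16 and undo the scale (the 1.0 / scale the diff passes around).
x_hat = q.to(torch.bfloat16) * (1.0 / scale)

print((x - x_hat).abs().max())  # quantization error stays small for a well-chosen scale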
@@ -67,6 +117,7 @@ def paged_attention(
     hpu_attention_meta: HPUPagedAttentionMetadata,
 ):
     batch_size, head_num, head_size = query.shape
+    fp8_kv = kv_cache.key.dtype == torch.float8_e4m3fn
     output = ops.flat_pa(
         query=query.view(batch_size, 1, head_num * head_size),
         key_cache=kv_cache.key,
@@ -76,12 +127,12 @@ def paged_attention(
         block_bias=hpu_attention_meta.attn_bias,
         block_groups=hpu_attention_meta.block_groups,
         scale=softmax_scale,
-        matmul_qk_op=Matmul(),
-        matmul_av_op=Matmul(),
+        matmul_qk_op=FP8Matmul(kv_scales.key_scale) if fp8_kv else Matmul(),
+        matmul_av_op=FP8Matmul(kv_scales.value_scale) if fp8_kv else Matmul(),
         batch2block_matmul_op=Matmul(),
         block2batch_matmul_op=Matmul(),
-        keys_fetch_func=fetch_from_cache,
-        values_fetch_func=fetch_from_cache,
+        keys_fetch_func=FetchFromCache(1.0 / kv_scales.key_scale_cpu),
+        values_fetch_func=FetchFromCache(1.0 / kv_scales.value_scale_cpu),
     )
     # Reshape the output tensor.
     return output.view(batch_size, head_num, head_size)
@@ -50,6 +50,8 @@ class KVCache:
     ):
         """Construct the key-value cache for a layer."""
         ## TODO FP8 kv cache support
+        if dtype is torch.float8_e5m2:
+            raise ValueError("torch.float8_e5m2 is not supported in hpu. ")
 
         self.kv_cache = (
             torch.zeros(
@@ -101,8 +103,8 @@ class KVCache:
             key_cache,
             value_cache,
             slots,
-            kv_scales.key_scale_cpu,
-            kv_scales.value_scale_cpu,
+            kv_scales.key_scale,
+            kv_scales.value_scale,
         )
 
 
@@ -112,11 +114,18 @@ def paged_reshape_and_cache(
     key_cache: torch.Tensor,
     value_cache: torch.Tensor,
     slots: torch.Tensor,
-    k_scale: float = 1.0,
-    v_scale: float = 1.0,
+    k_scale: torch.Tensor,
+    v_scale: torch.Tensor,
 ):
     block_idx = slots // BLOCK_SIZE
     block_offset = slots % BLOCK_SIZE
+    if key_cache.dtype == torch.float8_e4m3fn:
+        key = torch.ops.hpu.cast_to_fp8_v2(
+            key, k_scale, False, False, torch.float8_e4m3fn
+        )[0]
+        value = torch.ops.hpu.cast_to_fp8_v2(
+            value, v_scale, False, False, torch.float8_e4m3fn
+        )[0]
     cache_ops.insert_or_update_cache(key, key_cache, block_idx, block_offset)
     cache_ops.insert_or_update_cache(value, value_cache, block_idx, block_offset)
 
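
[Note] paged_reshape_and_cache addresses the paged cache by splitting each flat slot index into a block index and an offset within that block. A small sketch of the arithmetic, assuming an illustrative BLOCK_SIZE of 16 (the real constant comes from the attention backend):

import torch

BLOCK_SIZE = 16  # assumed value, for illustration only
slots = torch.tensor([0, 5, 16, 35])

block_idx = slots // BLOCK_SIZE    # tensor([0, 0, 1, 2])
block_offset = slots % BLOCK_SIZE  # tensor([0, 5, 0, 3])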
@@ -324,6 +324,7 @@ def get_model(
     quantize: Optional[str],
     speculate: Optional[int],
     dtype: Optional[torch.dtype],
+    kv_cache_dtype: Optional[str],
     trust_remote_code: bool,
     max_input_tokens: int,
 ) -> Model:
@@ -449,7 +450,12 @@ def get_model(
 
     model_type = config_dict["model_type"]
 
-    kv_cache_dtype = dtype
+    if kv_cache_dtype == "fp8_e4m3fn":
+        kv_cache_dtype = torch.float8_e4m3fn
+    elif kv_cache_dtype == "fp8_e5m2":
+        kv_cache_dtype = torch.float8_e5m2
+    else:
+        kv_cache_dtype = dtype
 
     if FLASH_ATTENTION:
         if model_type == DEEPSEEK_V2:
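
[Note] get_model now maps the CLI string onto a torch dtype, falling back to the model dtype when no FP8 variant is requested. A standalone sketch of the same mapping as a helper (hypothetical function name; the diff inlines this logic):

import torch
from typing import Optional

def resolve_kv_cache_dtype(kv_cache_dtype: Optional[str], dtype: torch.dtype) -> torch.dtype:
    if kv_cache_dtype == "fp8_e4m3fn":
        return torch.float8_e4m3fn
    elif kv_cache_dtype == "fp8_e5m2":
        return torch.float8_e5m2
    return dtype

assert resolve_kv_cache_dtype(None, torch.bfloat16) is torch.bfloat16
assert resolve_kv_cache_dtype("fp8_e4m3fn", torch.bfloat16) is torch.float8_e4m3fn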
@@ -890,6 +896,7 @@ def get_model_with_lora_adapters(
     quantize: Optional[str],
     speculate: Optional[int],
     dtype: Optional[torch.dtype],
+    kv_cache_dtype: Optional[str],
     trust_remote_code: bool,
     max_input_tokens: int,
     adapter_to_index: Dict[str, int],
@@ -903,6 +910,7 @@ def get_model_with_lora_adapters(
         quantize,
         speculate,
         dtype,
+        kv_cache_dtype,
         trust_remote_code,
         max_input_tokens,
     )
@@ -1438,6 +1438,7 @@ class FlashCausalLM(Model):
         self.kv_cache = []
         self.kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype
         self.bucketing_ctx = None
+        htorch.core.hpu_set_env()
         if htorch.utils.internal.is_lazy():
             htorch.hpu.wrap_in_hpu_graph(model, disable_tensor_cache=True)
         environment.set_model_config(self.config)
@@ -224,6 +224,7 @@ def serve(
     quantize: Optional[str],
     speculate: Optional[int],
     dtype: Optional[str],
+    kv_cache_dtype: Optional[str],
     trust_remote_code: bool,
     uds_path: Path,
     max_input_tokens: int,
@@ -236,6 +237,7 @@ def serve(
         quantize: Optional[str] = None,
         speculate: Optional[int] = None,
         dtype: Optional[str] = None,
+        kv_cache_dtype: Optional[str] = None,
         trust_remote_code: bool = False,
     ):
         if not is_driver_compatible():
@@ -279,6 +281,7 @@ def serve(
                 quantize,
                 speculate,
                 data_type,
+                kv_cache_dtype,
                 trust_remote_code,
                 max_input_tokens,
                 adapter_to_index,
@@ -326,6 +329,7 @@ def serve(
             quantize,
             speculate,
             dtype,
+            kv_cache_dtype,
             trust_remote_code,
         )
     )