mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-05-23 20:12:06 +00:00
180 lines
5.8 KiB
Python
180 lines
5.8 KiB
Python
import torch
|
|
from text_generation_server.layers.attention import Seqlen, HPUPagedAttentionMetadata
|
|
from typing import Optional
|
|
from text_generation_server.layers.attention.kv_cache import KVCache, KVScales
|
|
from vllm_hpu_extension import ops
|
|
from vllm_hpu_extension.utils import Matmul
|
|
from habana_frameworks.torch.hpex.kernels import FusedSDPA
|
|
from vllm_hpu_extension.utils import ModuleFusedSDPA
|
|
import os
|
|
from text_generation_server.models.globals import BLOCK_SIZE
|
|
|
|
# Sliding-window attention is not implemented by this backend: `attention`
# accepts a `window_size_left` argument but never reads it.
# NOTE(review): presumably model code consults this flag before requesting
# windowed attention — confirm at the call sites.
SUPPORTS_WINDOWING = False
|
|
|
|
|
|
class FP8Matmul(torch.nn.Module):
    """Matrix multiply that runs both operands through an FP8 (e4m3fn) GEMM on HPU.

    The left operand is quantized with a fixed unit scale stored in
    ``scale_input``; the right operand uses the ``scale_other`` given at
    construction time. ``fp8_gemm_v2`` receives the reciprocal of each scale
    and produces a bfloat16 result.
    """

    def __init__(self, scale_other):
        super().__init__()
        # Unit scale for the left-hand operand, allocated directly on the HPU.
        self.scale_input = torch.tensor(1.0, dtype=torch.bfloat16, device="hpu")
        self.scale_other = scale_other

    def quant_input(self, x, scale):
        """Cast ``x`` to ``torch.float8_e4m3fn`` with ``scale`` and return the cast tensor.

        The HPU op returns an indexable result; only its first element is kept.
        The two ``False`` flags are forwarded unchanged.
        NOTE(review): they are presumably the stochastic-rounding / amax
        switches — confirm against the Habana op documentation.
        """
        cast_result = torch.ops.hpu.cast_to_fp8_v2(
            x, scale, False, False, torch.float8_e4m3fn
        )
        return cast_result[0]

    def matmul_fp8(
        self, x, other, out_dtype, scale_input_inv=None, scale_other_inv=None
    ):
        """Compute ``x @ other`` with ``fp8_gemm_v2``.

        No transposes, bias, or accumulation are used. The optional inverse
        scales are forwarded as ``A_scale_inv`` / ``B_scale_inv``.
        """
        gemm_kwargs = dict(
            A=x,
            trans_A=False,
            B=other,
            trans_B=False,
            D=None,
            out_dtype=out_dtype,
            A_scale_inv=scale_input_inv,
            B_scale_inv=scale_other_inv,
            bias=None,
            accumulate=False,
        )
        return torch.ops.hpu.fp8_gemm_v2(**gemm_kwargs)

    def forward(self, input, other):
        """Quantize both operands, multiply them in FP8, and return a bfloat16 tensor."""
        lhs_fp8 = self.quant_input(input, self.scale_input)
        rhs_fp8 = self.quant_input(other, self.scale_other)
        lhs_scale_inv = 1.0 / self.scale_input
        rhs_scale_inv = 1.0 / self.scale_other
        return self.matmul_fp8(
            lhs_fp8,
            rhs_fp8,
            out_dtype=torch.bfloat16,
            scale_input_inv=lhs_scale_inv,
            scale_other_inv=rhs_scale_inv,
        )
|
|
|
|
|
|
class FetchFromCache(torch.nn.Module):
|
|
|
|
def __init__(self, scale_inv):
|
|
super().__init__()
|
|
self.scale_inv = scale_inv
|
|
|
|
def forward(self, cache, blocks):
|
|
if os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() == "true":
|
|
out = cache[: blocks.size(0)]
|
|
else:
|
|
out = cache.index_select(0, blocks)
|
|
if out.dtype == torch.float8_e4m3fn:
|
|
out = torch.ops.hpu.cast_from_fp8(out, self.scale_inv, torch.bfloat16)
|
|
return out
|
|
|
|
|
|
def attention(
    *,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: KVCache,
    kv_scales: KVScales,
    seqlen: Seqlen,
    softmax_scale: float,
    window_size_left: int = -1,
    causal: bool = True,
    softcap: Optional[float] = None,
):
    """Run prefill attention for a left-padded batch with HPU ``FusedSDPA``.

    ``query``, ``key`` and ``value`` are 3-D ``(tokens, heads, head_dim)``
    tensors. Each is regrouped to ``(batch, heads, seq, head_dim)``, with the
    batch size taken from ``seqlen.input_lengths``.
    NOTE(review): this presumes every sequence occupies the same padded number
    of tokens — confirm with the caller's padding scheme.

    ``seqlen.input_lengths`` is also passed as ``valid_sequence_lengths``,
    with ``padding_side="left"``.

    ``kv_cache``, ``kv_scales``, ``window_size_left`` and ``softcap`` are
    accepted for interface compatibility but are not used by this backend.

    Returns:
        The attention output laid out as ``(batch, seq, heads, head_dim)``,
        made contiguous. The leading batch dimension is removed by
        ``squeeze(0)`` only when ``batch == 1``.
    """
    fsdpa_op = ModuleFusedSDPA(FusedSDPA)
    bs = seqlen.input_lengths.shape[0]

    # Unpacking also enforces that each input is 3-D.
    _, head_num, q_head_size = query.shape
    _, kv_head_num, k_head_size = key.shape
    # Reshape each tensor with its own head dim. Previously the key's head dim
    # was reused for the query and value views; with a ``-1`` sequence axis, a
    # mismatched value head dim would silently reinterpret the sequence axis.
    v_head_size = value.shape[-1]

    query = query.view(bs, -1, head_num, q_head_size).transpose(1, 2)
    key = key.view(bs, -1, kv_head_num, k_head_size).transpose(1, 2)
    value = value.view(bs, -1, kv_head_num, v_head_size).transpose(1, 2)

    attn_output = fsdpa_op(
        query,
        key,
        value,
        attn_mask=None,
        dropout_p=0.0,
        is_causal=causal,
        scale=softmax_scale,
        softmax_mode="None",
        recompute_mode=None,
        valid_sequence_lengths=seqlen.input_lengths,
        padding_side="left",
    )
    # (batch, heads, seq, d) -> (batch, seq, heads, d); squeeze(0) only drops
    # the batch dim when batch == 1.
    attn_output = attn_output.transpose(1, 2).squeeze(0).contiguous()
    return attn_output
|
|
|
|
|
|
def paged_attention(
    query: torch.Tensor,
    kv_cache: KVCache,
    kv_head_mapping: torch.Tensor,
    softmax_scale: float,
    seqlen: Seqlen,
    *,
    kv_scales: KVScales,
    softcap: Optional[float] = None,
    hpu_attention_meta: HPUPagedAttentionMetadata,
):
    """Attend ``query`` against the paged KV cache using ``ops.flat_pa``.

    ``query`` is unpacked as ``(batch_size, head_num, head_size)``. It is
    flattened to ``(batch_size, 1, head_num * head_size)`` for the kernel, and
    the result is viewed back to ``(batch_size, head_num, head_size)``.

    When the cache dtype is ``torch.float8_e4m3fn``, the QK and AV matmuls use
    ``FP8Matmul`` with the key/value scales; otherwise plain ``Matmul``.
    Keys and values are read through ``FetchFromCache``, using the reciprocal
    of the CPU-side scales.

    ``kv_head_mapping``, ``seqlen`` and ``softcap`` are accepted for interface
    compatibility but are not used here.
    """
    batch_size, head_num, head_size = query.shape
    flat_query = query.view(batch_size, 1, head_num * head_size)

    if kv_cache.dtype == torch.float8_e4m3fn:
        qk_op = FP8Matmul(kv_scales.key_scale)
        av_op = FP8Matmul(kv_scales.value_scale)
    else:
        qk_op = Matmul()
        av_op = Matmul()

    meta = hpu_attention_meta
    output = ops.flat_pa(
        query=flat_query,
        key_cache=kv_cache.key,
        value_cache=kv_cache.value,
        block_list=meta.block_list,
        block_mapping=meta.block_mapping,
        block_bias=meta.attn_bias,
        block_groups=meta.block_groups,
        block_size=BLOCK_SIZE,
        scale=softmax_scale,
        matmul_qk_op=qk_op,
        matmul_av_op=av_op,
        batch2block_matmul_op=Matmul(),
        block2batch_matmul_op=Matmul(),
        keys_fetch_func=FetchFromCache(1.0 / kv_scales.key_scale_cpu),
        values_fetch_func=FetchFromCache(1.0 / kv_scales.value_scale_cpu),
    )
    # Restore the per-head layout of the input query.
    return output.view(batch_size, head_num, head_size)
|
|
|
|
|
|
def paged_attention_mla(
    query: torch.Tensor,
    kv_cache: KVCache,
    kv_head_mapping: torch.Tensor,
    softmax_scale: float,
    seqlen: Seqlen,
    *,
    kv_scales: KVScales,
    softcap: Optional[float] = None,
    hpu_attention_meta: HPUPagedAttentionMetadata,
    kv_lora_rank: int = 0,
):
    """Paged attention variant that calls ``ops.flat_pa_mla``.

    Only the key cache is supplied: ``value_cache`` and ``values_fetch_func``
    are passed as ``None``, and ``kv_lora_rank`` is forwarded to the kernel.
    ``query`` (unpacked as 3-D, ``(batch_size, head_num, ...)``) is passed
    unreshaped. The output is viewed to ``(batch_size, head_num, -1)``.

    When the cache dtype is ``torch.float8_e4m3fn``, the QK and AV matmuls use
    ``FP8Matmul``; otherwise plain ``Matmul``.
    NOTE(review): the AV matmul still uses ``kv_scales.value_scale`` even though
    no separate value cache is read here — confirm this is the intended scale.

    ``kv_head_mapping``, ``seqlen`` and ``softcap`` are accepted for interface
    compatibility but are not used here.
    """
    batch_size, head_num, _ = query.shape

    if kv_cache.dtype == torch.float8_e4m3fn:
        qk_op = FP8Matmul(kv_scales.key_scale)
        av_op = FP8Matmul(kv_scales.value_scale)
    else:
        qk_op = Matmul()
        av_op = Matmul()

    meta = hpu_attention_meta
    output = ops.flat_pa_mla(
        query=query,
        key_cache=kv_cache.key,
        value_cache=None,
        block_list=meta.block_list,
        block_mapping=meta.block_mapping,
        block_bias=meta.attn_bias,
        block_groups=meta.block_groups,
        block_size=BLOCK_SIZE,
        scale=softmax_scale,
        matmul_qk_op=qk_op,
        matmul_av_op=av_op,
        batch2block_matmul_op=Matmul(),
        block2batch_matmul_op=Matmul(),
        keys_fetch_func=FetchFromCache(1.0 / kv_scales.key_scale_cpu),
        values_fetch_func=None,
        kv_lora_rank=kv_lora_rank,
    )
    # The trailing dim is inferred from the kernel output.
    return output.view(batch_size, head_num, -1)
|
|
|
|
|
|
# Public API of this module; the FP8Matmul and FetchFromCache helpers are
# intentionally not exported.
__all__ = ["SUPPORTS_WINDOWING", "attention", "paged_attention", "paged_attention_mla"]
|