text-generation-inference/backends/gaudi/server/text_generation_server/layers/attention/hpu.py
2025-05-22 15:29:16 +02:00

180 lines
5.8 KiB
Python

import torch
from text_generation_server.layers.attention import Seqlen, HPUPagedAttentionMetadata
from typing import Optional
from text_generation_server.layers.attention.kv_cache import KVCache, KVScales
from vllm_hpu_extension import ops
from vllm_hpu_extension.utils import Matmul
from habana_frameworks.torch.hpex.kernels import FusedSDPA
from vllm_hpu_extension.utils import ModuleFusedSDPA
import os
from text_generation_server.models.globals import BLOCK_SIZE
# `attention` below never reads its `window_size_left` argument, so this
# backend reports sliding-window attention as unsupported (the flag is part
# of the module's public API; callers presumably consult it before passing a
# window).
SUPPORTS_WINDOWING: bool = False
class FP8Matmul(torch.nn.Module):
    """Matmul that quantizes both operands to FP8 (e4m3) and runs an HPU FP8 GEMM.

    The left operand is quantized with a fixed unit scale and the right operand
    with ``scale_other``. The reciprocals of both scales are handed to the GEMM
    as ``A_scale_inv`` / ``B_scale_inv``, and the result is produced in
    ``torch.bfloat16``.

    Args:
        scale_other: scale used when quantizing the right-hand operand
            (``other``); assumed to be a scalar or scalar tensor — TODO confirm.
    """

    def __init__(self, scale_other):
        super().__init__()
        # Unit scale for the left operand, allocated directly on the HPU device.
        self.scale_input = torch.tensor(1.0, dtype=torch.bfloat16, device="hpu")
        self.scale_other = scale_other

    def quant_input(self, x, scale):
        """Cast ``x`` to ``torch.float8_e4m3fn`` using ``scale``.

        Returns only the first element of what ``cast_to_fp8_v2`` produces
        (the casted tensor).
        """
        cast_outputs = torch.ops.hpu.cast_to_fp8_v2(
            x, scale, False, False, torch.float8_e4m3fn
        )
        return cast_outputs[0]

    def matmul_fp8(
        self, x, other, out_dtype, scale_input_inv=None, scale_other_inv=None
    ):
        """Run ``torch.ops.hpu.fp8_gemm_v2`` on ``x`` and ``other``.

        Neither operand is transposed, and no bias, output buffer or
        accumulation is used.

        Args:
            x: left FP8 operand (``A``).
            other: right FP8 operand (``B``).
            out_dtype: dtype of the returned product.
            scale_input_inv: inverse scale for ``x`` (``A_scale_inv``).
            scale_other_inv: inverse scale for ``other`` (``B_scale_inv``).
        """
        gemm_kwargs = dict(
            A=x,
            trans_A=False,
            B=other,
            trans_B=False,
            D=None,
            out_dtype=out_dtype,
            A_scale_inv=scale_input_inv,
            B_scale_inv=scale_other_inv,
            bias=None,
            accumulate=False,
        )
        return torch.ops.hpu.fp8_gemm_v2(**gemm_kwargs)

    def forward(self, input, other):
        """Quantize ``input`` and ``other`` and return their FP8 GEMM in bfloat16."""
        lhs_fp8 = self.quant_input(input, self.scale_input)
        rhs_fp8 = self.quant_input(other, self.scale_other)
        return self.matmul_fp8(
            lhs_fp8,
            rhs_fp8,
            out_dtype=torch.bfloat16,
            scale_input_inv=1.0 / self.scale_input,
            scale_other_inv=1.0 / self.scale_other,
        )
class FetchFromCache(torch.nn.Module):
def __init__(self, scale_inv):
super().__init__()
self.scale_inv = scale_inv
def forward(self, cache, blocks):
if os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() == "true":
out = cache[: blocks.size(0)]
else:
out = cache.index_select(0, blocks)
if out.dtype == torch.float8_e4m3fn:
out = torch.ops.hpu.cast_from_fp8(out, self.scale_inv, torch.bfloat16)
return out
def attention(
    *,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: KVCache,
    kv_scales: KVScales,
    seqlen: Seqlen,
    softmax_scale: float,
    window_size_left: int = -1,
    causal: bool = True,
    softcap: Optional[float] = None,
):
    """Attention over ``query``/``key``/``value`` using Habana's fused SDPA kernel.

    ``query`` and ``key`` must be 3-D. Their shapes unpack as
    ``(_, head_num, head_size)`` and ``(_, kv_head_num, head_size)``. Each of
    the three tensors is viewed as ``(batch, -1, heads, head_size)``, with
    ``batch = seqlen.input_lengths.shape[0]``. This presumably requires every
    sequence to be padded to the same length — TODO confirm with callers.
    The views are then transposed to heads-first order for the kernel.

    ``seqlen.input_lengths`` is passed as ``valid_sequence_lengths`` with left
    padding. ``kv_cache``, ``kv_scales``, ``window_size_left`` and ``softcap``
    are accepted for interface compatibility but are not read here.

    Returns:
        The kernel output transposed back on dims 1 and 2, with dim 0 squeezed
        (this only drops it when the batch size is 1), made contiguous.
    """
    fused_sdpa = ModuleFusedSDPA(FusedSDPA)
    batch = seqlen.input_lengths.shape[0]
    _, q_heads, _ = query.shape
    # NOTE: head_size comes from `key` and is reused for the query and value
    # views, so all three tensors are assumed to share it.
    _, kv_heads, head_size = key.shape

    def to_heads_first(tensor, num_heads):
        # (batch * seq, heads, head_size) -> (batch, heads, seq, head_size)
        return tensor.view(batch, -1, num_heads, head_size).transpose(1, 2)

    out = fused_sdpa(
        to_heads_first(query, q_heads),
        to_heads_first(key, kv_heads),
        to_heads_first(value, kv_heads),
        attn_mask=None,
        dropout_p=0.0,
        is_causal=causal,
        scale=softmax_scale,
        softmax_mode="None",
        recompute_mode=None,
        valid_sequence_lengths=seqlen.input_lengths,
        padding_side="left",
    )
    return out.transpose(1, 2).squeeze(0).contiguous()
def paged_attention(
    query: torch.Tensor,
    kv_cache: KVCache,
    kv_head_mapping: torch.Tensor,
    softmax_scale: float,
    seqlen: Seqlen,
    *,
    kv_scales: KVScales,
    softcap: Optional[float] = None,
    hpu_attention_meta: HPUPagedAttentionMetadata,
):
    """Paged attention over the KV cache via ``vllm_hpu_extension.ops.flat_pa``.

    ``query`` must be 3-D ``(batch_size, head_num, head_size)``. It is passed to
    the kernel as ``(batch_size, 1, head_num * head_size)``, i.e. one query
    token per sequence. Block layout (list, mapping, bias, groups) comes from
    ``hpu_attention_meta``.

    When the cache dtype is ``torch.float8_e4m3fn``, the QK and AV matmuls are
    FP8 matmuls scaled by ``kv_scales.key_scale`` / ``kv_scales.value_scale``.
    Otherwise plain ``Matmul`` ops are used.

    ``kv_head_mapping``, ``seqlen`` and ``softcap`` are accepted for interface
    compatibility but are not read here.

    Returns:
        The kernel output viewed as ``(batch_size, head_num, head_size)``.
    """
    batch_size, head_num, head_size = query.shape
    is_fp8_cache = kv_cache.dtype == torch.float8_e4m3fn

    if is_fp8_cache:
        qk_matmul = FP8Matmul(kv_scales.key_scale)
        av_matmul = FP8Matmul(kv_scales.value_scale)
    else:
        qk_matmul = Matmul()
        av_matmul = Matmul()

    # Cache fetchers receive the inverse CPU-side scales; FetchFromCache only
    # applies them when the fetched blocks are FP8.
    key_fetch = FetchFromCache(1.0 / kv_scales.key_scale_cpu)
    value_fetch = FetchFromCache(1.0 / kv_scales.value_scale_cpu)

    meta = hpu_attention_meta
    output = ops.flat_pa(
        query=query.view(batch_size, 1, head_num * head_size),
        key_cache=kv_cache.key,
        value_cache=kv_cache.value,
        block_list=meta.block_list,
        block_mapping=meta.block_mapping,
        block_bias=meta.attn_bias,
        block_groups=meta.block_groups,
        block_size=BLOCK_SIZE,
        scale=softmax_scale,
        matmul_qk_op=qk_matmul,
        matmul_av_op=av_matmul,
        batch2block_matmul_op=Matmul(),
        block2batch_matmul_op=Matmul(),
        keys_fetch_func=key_fetch,
        values_fetch_func=value_fetch,
    )
    return output.view(batch_size, head_num, head_size)
def paged_attention_mla(
    query: torch.Tensor,
    kv_cache: KVCache,
    kv_head_mapping: torch.Tensor,
    softmax_scale: float,
    seqlen: Seqlen,
    *,
    kv_scales: KVScales,
    softcap: Optional[float] = None,
    hpu_attention_meta: HPUPagedAttentionMetadata,
    kv_lora_rank: int = 0,
):
    """MLA-style paged attention via ``vllm_hpu_extension.ops.flat_pa_mla``.

    Only the key cache is passed; ``value_cache`` and ``values_fetch_func`` are
    ``None``, and ``kv_lora_rank`` is forwarded to the kernel. ``query`` must be
    3-D and is passed unchanged.

    When the cache dtype is ``torch.float8_e4m3fn``, the QK and AV matmuls are
    FP8 matmuls scaled by ``kv_scales.key_scale`` / ``kv_scales.value_scale``.
    Otherwise plain ``Matmul`` ops are used.

    ``kv_head_mapping``, ``seqlen`` and ``softcap`` are accepted for interface
    compatibility but are not read here.

    Returns:
        The kernel output viewed as ``(batch_size, head_num, -1)``.
    """
    # Only the first two dims are needed; the 3-way unpack still enforces 3-D.
    batch_size, head_num, _ = query.shape
    is_fp8_cache = kv_cache.dtype == torch.float8_e4m3fn

    if is_fp8_cache:
        qk_matmul = FP8Matmul(kv_scales.key_scale)
        # NOTE(review): the AV matmul uses value_scale even though no value
        # cache is passed here — confirm this is the intended scale.
        av_matmul = FP8Matmul(kv_scales.value_scale)
    else:
        qk_matmul = Matmul()
        av_matmul = Matmul()

    key_fetch = FetchFromCache(1.0 / kv_scales.key_scale_cpu)

    meta = hpu_attention_meta
    output = ops.flat_pa_mla(
        query=query,
        key_cache=kv_cache.key,
        value_cache=None,
        block_list=meta.block_list,
        block_mapping=meta.block_mapping,
        block_bias=meta.attn_bias,
        block_groups=meta.block_groups,
        block_size=BLOCK_SIZE,
        scale=softmax_scale,
        matmul_qk_op=qk_matmul,
        matmul_av_op=av_matmul,
        batch2block_matmul_op=Matmul(),
        block2batch_matmul_op=Matmul(),
        keys_fetch_func=key_fetch,
        values_fetch_func=None,
        kv_lora_rank=kv_lora_rank,
    )
    return output.view(batch_size, head_num, -1)
# Public API of this attention backend; FP8Matmul and FetchFromCache are
# helpers and deliberately not exported.
__all__ = [
    "SUPPORTS_WINDOWING",
    "attention",
    "paged_attention",
    "paged_attention_mla",
]