text-generation-inference/backends/gaudi/server/text_generation_server/layers/attention/hpu.py

96 lines
2.8 KiB
Python
Raw Permalink Normal View History

Gaudi: clean cuda/rocm code in hpu backend, enable flat_hpu (#3113) * clean cuda/rocm code in hpu backend, enable flat_hpu Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * fix TP in pageattn Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * adjust block table in hpu to improve performance Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * enable all the model. not testet yet Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * use tensor cache in hpu graph to avoid replay issue Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * add moe support, fix qwen/mistral/mixtral crash Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * fix phimoe issue Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * gpt_bigcode could also go pageattn Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * enable dbrx remove some unused code Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * multi-modality initial PR Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * adjust warmup and enable vlm Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * fix incorrect output in qwen2 idefics if hpu graph is used Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * remove unused quantization code and enable awq/gptq int4 Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * fix gptq issue Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * enable fp8 Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * warmup prefill remove model where pageattn is not used, set block table to None since it's not used Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * add warmup_decode Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * warmup decode Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * remove block_tables and prefill_cache_indices which will lead to dynamic shape Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * fix comment Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * missing gptj change... Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * fix some issue Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * remove torch.where to fix incorrect output in hpu graph model Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * match the latest vllm_extension ops Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> --------- Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
2025-04-14 13:58:13 +00:00
import torch
from text_generation_server.layers.attention import Seqlen, HPUPagedAttentionMetadata
from typing import Optional
from text_generation_server.layers.attention.kv_cache import KVCache, KVScales
from vllm_hpu_extension import ops
from vllm_hpu_extension.utils import Matmul
from habana_frameworks.torch.hpex.kernels import FusedSDPA
from vllm_hpu_extension.utils import ModuleFusedSDPA
import os
SUPPORTS_WINDOWING = False
def fetch_from_cache(cache, blocks):
if os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() == "true":
return cache[: blocks.size(0)]
else:
return cache.index_select(0, blocks)
def attention(
*,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: KVCache,
kv_scales: KVScales,
seqlen: Seqlen,
softmax_scale: float,
window_size_left: int = -1,
causal: bool = True,
softcap: Optional[float] = None,
):
fsdpa_op = ModuleFusedSDPA(FusedSDPA)
bs = seqlen.input_lengths.shape[0]
_, head_num, head_size = query.shape
_, kv_head_num, head_size = key.shape
query = query.view(bs, -1, head_num, head_size).transpose(1, 2)
key = key.view(bs, -1, kv_head_num, head_size).transpose(1, 2)
value = value.view(bs, -1, kv_head_num, head_size).transpose(1, 2)
attn_output = fsdpa_op(
query,
key,
value,
attn_mask=None,
dropout_p=0.0,
is_causal=causal,
scale=softmax_scale,
softmax_mode="None",
recompute_mode=None,
valid_sequence_lengths=seqlen.input_lengths,
padding_side="left",
)
attn_output = attn_output.transpose(1, 2).squeeze(0).contiguous()
return attn_output
def paged_attention(
query: torch.Tensor,
kv_cache: KVCache,
kv_head_mapping: torch.Tensor,
softmax_scale: float,
seqlen: Seqlen,
*,
kv_scales: KVScales,
softcap: Optional[float] = None,
hpu_attention_meta: HPUPagedAttentionMetadata,
):
batch_size, head_num, head_size = query.shape
output = ops.flat_pa(
query=query.view(batch_size, 1, head_num * head_size),
key_cache=kv_cache.key,
value_cache=kv_cache.value,
block_list=hpu_attention_meta.block_list,
block_mapping=hpu_attention_meta.block_mapping,
block_bias=hpu_attention_meta.attn_bias,
block_scales=hpu_attention_meta.block_scales,
block_groups=hpu_attention_meta.block_groups,
scale=softmax_scale,
matmul_qk_op=Matmul(),
matmul_av_op=Matmul(),
batch2block_matmul_op=Matmul(),
block2batch_matmul_op=Matmul(),
keys_fetch_func=fetch_from_cache,
values_fetch_func=fetch_from_cache,
)
# Reshape the output tensor.
return output.view(batch_size, head_num, head_size)
__all__ = [
"SUPPORTS_WINDOWING",
"attention",
"paged_attention",
]