Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-19 22:02:06 +00:00)
* clean cuda/rocm code in hpu backend, enable flat_hpu
* fix TP in pageattn
* adjust block table in hpu to improve performance
* enable all the models, not tested yet
* use tensor cache in hpu graph to avoid replay issue
* add moe support, fix qwen/mistral/mixtral crash
* fix phimoe issue
* gpt_bigcode could also go pageattn
* enable dbrx, remove some unused code
* multi-modality initial PR
* adjust warmup and enable vlm
* fix incorrect output in qwen2 idefics if hpu graph is used
* remove unused quantization code and enable awq/gptq int4
* fix gptq issue
* enable fp8
* warmup prefill: remove models where pageattn is not used, set block table to None since it's not used
* add warmup_decode
* warmup decode
* remove block_tables and prefill_cache_indices which will lead to dynamic shape
* fix comment
* missing gptj change...
* fix some issues
* remove torch.where to fix incorrect output in hpu graph model
* match the latest vllm_extension ops

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
96 lines
2.8 KiB
Python
import torch
from text_generation_server.layers.attention import Seqlen, HPUPagedAttentionMetadata
from typing import Optional
from text_generation_server.layers.attention.kv_cache import KVCache, KVScales
from vllm_hpu_extension import ops
from vllm_hpu_extension.utils import Matmul, ModuleFusedSDPA
from habana_frameworks.torch.hpex.kernels import FusedSDPA
import os

SUPPORTS_WINDOWING = False

def fetch_from_cache(cache, blocks):
    # Contiguous PA (the default): the referenced blocks are assumed to occupy
    # the leading entries of the cache, so a plain slice is enough; otherwise
    # gather the referenced blocks by index.
    if os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() == "true":
        return cache[: blocks.size(0)]
    else:
        return cache.index_select(0, blocks)
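# Illustrative example (not part of the original file): with a cache of shape
# [num_blocks, block_size, kv_heads, head_size] and blocks = tensor([3, 0, 7]),
# the contiguous-PA branch returns cache[:3], while the fallback branch gathers
# exactly blocks 3, 0 and 7 via index_select; both yield one entry per block.
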
def attention(
    *,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: KVCache,
    kv_scales: KVScales,
    seqlen: Seqlen,
    softmax_scale: float,
    window_size_left: int = -1,
    causal: bool = True,
    softcap: Optional[float] = None,
):
    fsdpa_op = ModuleFusedSDPA(FusedSDPA)
    bs = seqlen.input_lengths.shape[0]
    _, head_num, head_size = query.shape
    _, kv_head_num, head_size = key.shape
    # Unpack [total_tokens, heads, head_size] into [bs, heads, seq, head_size]
    # for the fused SDPA kernel.
    query = query.view(bs, -1, head_num, head_size).transpose(1, 2)
    key = key.view(bs, -1, kv_head_num, head_size).transpose(1, 2)
    value = value.view(bs, -1, kv_head_num, head_size).transpose(1, 2)
    attn_output = fsdpa_op(
        query,
        key,
        value,
        attn_mask=None,
        dropout_p=0.0,
        is_causal=causal,
        scale=softmax_scale,
        softmax_mode="None",
        recompute_mode=None,
        valid_sequence_lengths=seqlen.input_lengths,
        padding_side="left",
    )
    # Restore the [bs, seq, heads, head_size] layout; the batch dim is squeezed
    # away when bs == 1.
    attn_output = attn_output.transpose(1, 2).squeeze(0).contiguous()
    return attn_output
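# Added note: attention() above never reads kv_cache -- it attends over the
# dense query/key/value tensors passed in (the prefill path in this backend),
# while paged_attention() below gathers keys/values from the paged cache via
# ops.flat_pa (the decode path). window_size_left and softcap are accepted for
# interface compatibility only (SUPPORTS_WINDOWING is False) and are unused.
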
def paged_attention(
    query: torch.Tensor,
    kv_cache: KVCache,
    kv_head_mapping: torch.Tensor,
    softmax_scale: float,
    seqlen: Seqlen,
    *,
    kv_scales: KVScales,
    softcap: Optional[float] = None,
    hpu_attention_meta: HPUPagedAttentionMetadata,
):
    batch_size, head_num, head_size = query.shape
    # Decode path: one query token per sequence; flat_pa reads keys/values from
    # the paged cache using the block metadata in hpu_attention_meta.
    output = ops.flat_pa(
        query=query.view(batch_size, 1, head_num * head_size),
        key_cache=kv_cache.key,
        value_cache=kv_cache.value,
        block_list=hpu_attention_meta.block_list,
        block_mapping=hpu_attention_meta.block_mapping,
        block_bias=hpu_attention_meta.attn_bias,
        block_scales=hpu_attention_meta.block_scales,
        block_groups=hpu_attention_meta.block_groups,
        scale=softmax_scale,
        matmul_qk_op=Matmul(),
        matmul_av_op=Matmul(),
        batch2block_matmul_op=Matmul(),
        block2batch_matmul_op=Matmul(),
        keys_fetch_func=fetch_from_cache,
        values_fetch_func=fetch_from_cache,
    )
    # Reshape the output tensor.
    return output.view(batch_size, head_num, head_size)

__all__ = [
    "SUPPORTS_WINDOWING",
    "attention",
    "paged_attention",
]
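For orientation, here is a minimal, self-contained sketch of the shape handling used above. It runs with plain torch (no HPU, no vllm_hpu_extension import); the batch size, sequence length, head counts and cache layout are made-up illustration values, and the slice/gather lines mirror fetch_from_cache rather than calling it.

import torch

# Hypothetical sizes, for illustration only.
bs, seq, head_num, head_size = 2, 4, 8, 64

# attention() receives packed [total_tokens, head_num, head_size] tensors
# (equal, padded sequence lengths assumed) and reshapes them to
# [bs, head_num, seq, head_size] before calling the fused SDPA kernel.
query = torch.randn(bs * seq, head_num, head_size)
batched = query.view(bs, -1, head_num, head_size).transpose(1, 2)
assert batched.shape == (bs, head_num, seq, head_size)

# fetch_from_cache() either slices the first blocks.size(0) entries of the
# cache (contiguous PA) or gathers the referenced blocks by index; both
# return one cache entry per referenced block. The cache layout is assumed.
num_blocks, block_size, kv_heads = 16, 128, 8
cache = torch.randn(num_blocks, block_size, kv_heads, head_size)
blocks = torch.tensor([3, 0, 7])
contiguous = cache[: blocks.size(0)]       # VLLM_CONTIGUOUS_PA=true path
gathered = cache.index_select(0, blocks)   # index_select fallback
assert contiguous.shape == gathered.shape == (3, block_size, kv_heads, head_size)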