mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-21 23:12:07 +00:00
* clean cuda/rocm code in hpu backend, enable flat_hpu Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * fix TP in pageattn Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * adjust block table in hpu to improve performance Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * enable all the model. not tested yet Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * use tensor cache in hpu graph to avoid replay issue Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * add moe support, fix qwen/mistral/mixtral crash Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * fix phimoe issue Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * gpt_bigcode could also go pageattn Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * enable dbrx remove some unused code Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * multi-modality initial PR Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * adjust warmup and enable vlm Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * fix incorrect output in qwen2 idefics if hpu graph is used Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * remove unused quantization code and enable awq/gptq int4 Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * fix gptq issue Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * enable fp8 Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * warmup prefill remove model where pageattn is not used, set block table to None since it's not used Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * add warmup_decode Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * warmup decode Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * remove block_tables and prefill_cache_indices which will lead to dynamic shape Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * fix comment Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * missing gptj change... 
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * fix some issue Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * remove torch.where to fix incorrect output in hpu graph model Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * match the latest vllm_extension ops Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> --------- Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
68 lines
1.9 KiB
Python
68 lines
1.9 KiB
Python
import torch
|
|
from torch import nn
|
|
from accelerate import init_empty_weights
|
|
|
|
|
|
# Monkey patching
|
|
@classmethod
def load_layer_norm(cls, prefix, weights, eps):
    """Build a layer-norm module from checkpoint tensors, bias included.

    Written as a classmethod so it can be attached to a norm class and used
    as an alternate constructor; ``cls`` is the class it is looked up on.

    Args:
        cls: Norm class to instantiate; called as ``cls(shape, eps=eps)``.
        prefix: Name prefix of the layer. The tensors requested are
            ``"{prefix}.weight"`` and ``"{prefix}.bias"``.
        weights: Object exposing ``get_tensor(name)``.
        eps: Epsilon forwarded to the constructor.

    Returns:
        The new module, with ``weight`` and ``bias`` set to parameters
        wrapping the loaded tensors.
    """
    scale = weights.get_tensor(f"{prefix}.weight")
    shift = weights.get_tensor(f"{prefix}.bias")

    # Instantiate inside accelerate's ``init_empty_weights`` context; the
    # parameters created by the constructor are replaced right below.
    with init_empty_weights():
        module = cls(scale.shape, eps=eps)

    module.weight = torch.nn.Parameter(scale)
    module.bias = torch.nn.Parameter(shift)
    return module
|
|
|
|
|
|
@classmethod
def load_layer_norm_no_bias(cls, prefix, weights, eps):
    """Build a layer-norm module from a checkpoint weight, with no bias.

    Written as a classmethod so it can be attached to a norm class and used
    as an alternate constructor; ``cls`` is the class it is looked up on.

    Args:
        cls: Norm class to instantiate; called as ``cls(shape, eps=eps)``.
        prefix: Name prefix of the layer. Only ``"{prefix}.weight"`` is
            requested.
        weights: Object exposing ``get_tensor(name)``.
        eps: Epsilon forwarded to the constructor.

    Returns:
        The new module, with ``weight`` set to a parameter wrapping the
        loaded tensor and ``bias`` set to ``None``.
    """
    scale = weights.get_tensor(f"{prefix}.weight")

    # Instantiate inside accelerate's ``init_empty_weights`` context; the
    # weight created by the constructor is replaced right below.
    with init_empty_weights():
        module = cls(scale.shape, eps=eps)

    module.weight = torch.nn.Parameter(scale)
    # Explicitly drop the bias the constructor created.
    module.bias = None
    return module
|
|
|
|
|
|
# Attach the loader functions to torch's LayerNorm class as alternate
# constructors, so ``LayerNorm.load(...)`` and ``LayerNorm.load_no_bias(...)``
# can be called on it (and on classes deriving from it).
nn.LayerNorm.load = load_layer_norm
nn.LayerNorm.load_no_bias = load_layer_norm_no_bias
|
|
|
|
|
|
class FastLayerNorm(nn.LayerNorm):
    """LayerNorm whose forward pass also folds in an optional residual."""

    def forward(self, hidden_states, residual=None):
        """Add the residual (when given) and apply layer normalization.

        Args:
            hidden_states: Input tensor. When ``residual`` is provided it is
                updated in place with ``+=``.
            residual: Optional tensor added to ``hidden_states`` before
                normalization.

        Returns:
            A tuple ``(normalized, pre_norm)``. ``pre_norm`` is the
            ``hidden_states`` object itself (after the in-place add, if any),
            i.e. the value that was fed to the normalization.
        """
        if residual is not None:
            # In-place accumulation: the caller's tensor is modified.
            hidden_states += residual

        pre_norm = hidden_states
        normalized = super().forward(pre_norm)
        return normalized, pre_norm
|
|
|
|
|
|
class FastRMSNorm(nn.Module):
    """RMS-norm module that delegates to ``vllm_hpu_extension``'s ``rms_norm``."""

    def __init__(self, weight: torch.Tensor, eps: float):
        """Store the scale tensor as a parameter together with the epsilon.

        Args:
            weight: Tensor wrapped in ``nn.Parameter`` and kept as
                ``self.weight``.
            eps: Value kept as ``self.variance_epsilon`` and handed to the
                kernel in ``forward``.
        """
        super().__init__()

        self.variance_epsilon = eps
        self.weight = nn.Parameter(weight)

    @classmethod
    def load(cls, prefix, weights, eps=1e-6):
        """Create the module from the ``"{prefix}.weight"`` checkpoint tensor.

        Args:
            prefix: Name prefix of the layer.
            weights: Object exposing ``get_tensor(name)``.
            eps: Epsilon passed to the constructor (defaults to ``1e-6``).

        Returns:
            A new instance of ``cls``.
        """
        return cls(weights.get_tensor(f"{prefix}.weight"), eps)

    def forward(self, hidden_states, residual=None):
        """Accumulate the residual and apply the ``rms_norm`` kernel to it.

        Args:
            hidden_states: Input tensor.
            residual: Optional tensor. When given, ``hidden_states`` (viewed
                with the residual's shape) is added to it in place;
                otherwise ``hidden_states`` itself plays the residual role.

        Returns:
            A tuple ``(normalized, residual)``, both viewed with the shape
            ``hidden_states`` had on entry.
        """
        # Resolved at call time rather than at module import.
        from vllm_hpu_extension.kernels import rms_norm

        input_shape = hidden_states.shape

        if residual is None:
            residual = hidden_states
        else:
            # In-place accumulation: the caller's residual tensor is modified.
            residual += hidden_states.view(residual.shape)

        # Note: HPUFusedRMSNorm requires 3D tensors as inputs
        if len(input_shape) == 2:
            residual = residual.unsqueeze(0)

        normalized = rms_norm().apply(residual, self.weight, self.variance_epsilon)
        return normalized.view(input_shape), residual.view(input_shape)
|