mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-21 14:52:20 +00:00
Adds support for AMD Instinct MI300 in TGI.
Most changes are:
* Support PyTorch TunableOp to pick the GEMM/GEMV kernels for decoding
https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/cuda/tunable.
TunableOp is disabled by default, and can be enabled with
`PYTORCH_TUNABLEOP_ENABLED=1`.
* Update ROCm dockerfile to PyTorch 2.3 (actually patched with changes
from https://github.com/pytorch/pytorch/pull/124362)
* Support SILU & Linear custom kernels contributed by AMD
* Update vLLM paged attention to https://github.com/fxmarty/rocm-vllm/,
branching out of a much more recent commit
3489ce7936
* Support FA2 Triton kernel as recommended by AMD. Can be used by
specifying `ROCM_USE_FLASH_ATTN_V2_TRITON=1`.
* Update dockerfile to ROCm 6.1
By default, TunableOp tuning results are saved in `/data` (e.g.
`/data/tunableop_meta-llama-Llama-2-70b-chat-hf_tp1_rank0.csv`) to
avoid having to rerun the tuning at each `docker run`.
Example:
```
Validator,PT_VERSION,2.3.0
Validator,ROCM_VERSION,6.1.0.0-82-5fabb4c
Validator,HIPBLASLT_VERSION,0.7.0-1549b021
Validator,GCN_ARCH_NAME,gfx942:sramecc+:xnack-
Validator,ROCBLAS_VERSION,4.1.0-cefa4a9b-dirty
GemmTunableOp_Half_TN,tn_8192_7_28672,Gemm_Rocblas_45475,0.132098
GemmTunableOp_Half_TN,tn_10240_4_8192,Gemm_Rocblas_45546,0.0484431
GemmTunableOp_Half_TN,tn_32000_6_8192,Default,0.149546
GemmTunableOp_Half_TN,tn_32000_3_8192,Gemm_Rocblas_45520,0.147119
GemmTunableOp_Half_TN,tn_8192_3_28672,Gemm_Rocblas_45475,0.132645
GemmTunableOp_Half_TN,tn_10240_3_8192,Gemm_Rocblas_45546,0.0482971
GemmTunableOp_Half_TN,tn_57344_5_8192,Gemm_Rocblas_45520,0.255694
GemmTunableOp_Half_TN,tn_10240_7_8192,Gemm_Rocblas_45517,0.0482522
GemmTunableOp_Half_TN,tn_8192_3_8192,Gemm_Rocblas_45546,0.0444671
GemmTunableOp_Half_TN,tn_8192_5_8192,Gemm_Rocblas_45546,0.0445834
GemmTunableOp_Half_TN,tn_57344_7_8192,Gemm_Rocblas_45520,0.25622
GemmTunableOp_Half_TN,tn_8192_2_28672,Gemm_Rocblas_45475,0.132122
GemmTunableOp_Half_TN,tn_8192_4_8192,Gemm_Rocblas_45517,0.0453191
GemmTunableOp_Half_TN,tn_10240_5_8192,Gemm_Rocblas_45517,0.0482514
GemmTunableOp_Half_TN,tn_8192_5_28672,Gemm_Rocblas_45542,0.133914
GemmTunableOp_Half_TN,tn_8192_2_8192,Gemm_Rocblas_45517,0.0446516
GemmTunableOp_Half_TN,tn_8192_1_28672,Gemm_Hipblaslt_TN_10814,0.131953
GemmTunableOp_Half_TN,tn_10240_2_8192,Gemm_Rocblas_45546,0.0481043
GemmTunableOp_Half_TN,tn_32000_4_8192,Gemm_Rocblas_45520,0.147497
GemmTunableOp_Half_TN,tn_8192_6_28672,Gemm_Rocblas_45529,0.134895
GemmTunableOp_Half_TN,tn_57344_2_8192,Gemm_Rocblas_45520,0.254716
GemmTunableOp_Half_TN,tn_57344_4_8192,Gemm_Rocblas_45520,0.255731
GemmTunableOp_Half_TN,tn_10240_6_8192,Gemm_Rocblas_45517,0.0484816
GemmTunableOp_Half_TN,tn_57344_3_8192,Gemm_Rocblas_45520,0.254701
GemmTunableOp_Half_TN,tn_8192_4_28672,Gemm_Rocblas_45475,0.132159
GemmTunableOp_Half_TN,tn_32000_2_8192,Default,0.147524
GemmTunableOp_Half_TN,tn_32000_5_8192,Default,0.147074
GemmTunableOp_Half_TN,tn_8192_6_8192,Gemm_Rocblas_45546,0.0454045
GemmTunableOp_Half_TN,tn_57344_6_8192,Gemm_Rocblas_45520,0.255582
GemmTunableOp_Half_TN,tn_32000_7_8192,Default,0.146705
GemmTunableOp_Half_TN,tn_8192_7_8192,Gemm_Rocblas_45546,0.0445489
```
---------
Co-authored-by: Mohit Sharma <mohit21sharma.ms@gmail.com>
186 lines
5.7 KiB
Python
186 lines
5.7 KiB
Python
import torch
|
|
from torch import nn
|
|
from accelerate import init_empty_weights
|
|
from text_generation_server.utils.import_utils import (
|
|
SYSTEM,
|
|
)
|
|
|
|
|
|
# Monkey patching
|
|
@classmethod
def load_layer_norm(cls, prefix, weights, eps):
    """Alternate constructor: build a layer norm from checkpoint tensors.

    Args:
        prefix: Name prefix; the tensors ``{prefix}.weight`` and
            ``{prefix}.bias`` are requested from ``weights``.
        weights: Object exposing ``get_tensor(name)``.
        eps: Epsilon forwarded to the ``cls`` constructor.

    Returns:
        An instance of ``cls`` whose ``weight`` and ``bias`` are
        ``torch.nn.Parameter`` wrappers around the loaded tensors.
    """
    scale = weights.get_tensor(f"{prefix}.weight")
    shift = weights.get_tensor(f"{prefix}.bias")

    # The module is created under init_empty_weights() and its parameters are
    # replaced right after. NOTE(review): presumably this avoids allocating
    # real storage for the placeholder parameters - confirm against accelerate.
    with init_empty_weights():
        module = cls(scale.shape, eps=eps)

    module.weight = torch.nn.Parameter(scale)
    module.bias = torch.nn.Parameter(shift)
    return module
|
|
|
|
|
|
@classmethod
def load_layer_norm_no_bias(cls, prefix, weights, eps):
    """Alternate constructor: build a bias-free layer norm from a checkpoint.

    Args:
        prefix: Name prefix; only ``{prefix}.weight`` is requested from
            ``weights``.
        weights: Object exposing ``get_tensor(name)``.
        eps: Epsilon forwarded to the ``cls`` constructor.

    Returns:
        An instance of ``cls`` whose ``weight`` is a ``torch.nn.Parameter``
        around the loaded tensor and whose ``bias`` is ``None``.
    """
    scale = weights.get_tensor(f"{prefix}.weight")

    # Same construction as the biased loader: instantiate under
    # init_empty_weights(), then swap in the real weight.
    with init_empty_weights():
        module = cls(scale.shape, eps=eps)

    module.weight = torch.nn.Parameter(scale)
    module.bias = None
    return module
|
|
|
|
|
|
# Attach the two classmethod loaders to the stock LayerNorm class so that it
# (and any subclass) can be built with `.load(...)` / `.load_no_bias(...)`.
nn.LayerNorm.load_no_bias = load_layer_norm_no_bias
nn.LayerNorm.load = load_layer_norm
|
|
|
|
if SYSTEM == "cuda":
    import dropout_layer_norm

    class FastLayerNorm(nn.LayerNorm):
        """LayerNorm with a residual add, using the fused CUDA extension."""

        def forward(self, hidden_states, residual=None):
            """Return ``(normed_hidden_states, residual)``.

            The returned residual is the value that was normalised
            (``hidden_states`` plus the incoming ``residual``, if any).
            """
            # The fused kernel is only used when the last dimension is at
            # most 8192; wider inputs take the plain PyTorch path below.
            if hidden_states.shape[-1] <= 8192:
                fused_out = dropout_layer_norm.dropout_add_ln_fwd(
                    hidden_states,
                    residual,
                    self.weight,
                    self.bias,
                    None,
                    None,
                    None,
                    None,
                    0.0,
                    self.eps,
                    1.0,
                    0,
                    None,
                    False,
                    False,  # RMSNorm mode off (FastRMSNorm passes True here)
                )
                normed, summed, *_ = fused_out
                # No residual came back from the kernel: fall back to the
                # input itself as the residual.
                if summed is None:
                    summed = hidden_states
                return normed, summed

            # NOTE: the addition is in place, so the caller's tensor is
            # modified.
            if residual is not None:
                hidden_states += residual
            return super().forward(hidden_states), hidden_states

elif SYSTEM == "rocm":
    from vllm._C import ops

    class FastLayerNorm(nn.LayerNorm):
        """LayerNorm with a residual add, using stock PyTorch on ROCm."""

        def forward(self, hidden_states, residual=None):
            """Return ``(normed_hidden_states, residual)``."""
            # In-place accumulation of the residual into the input.
            if residual is not None:
                hidden_states += residual
            return super().forward(hidden_states), hidden_states

elif SYSTEM == "xpu":
    import intel_extension_for_pytorch as ipex

    class FastLayerNorm(nn.LayerNorm):
        """LayerNorm with a residual add, delegating to the IPEX kernel."""

        def forward(self, hidden_states, residual=None):
            """Return ``(normed_hidden_states, residual)``."""
            normed = ipex.llm.functional.add_layer_norm(
                residual, hidden_states, self.weight, self.bias, self.eps, True
            )
            # With no incoming residual the input tensor is handed back as
            # the residual; otherwise the `residual` tensor is.
            # NOTE(review): presumably the trailing True makes the kernel
            # write the sum back into `residual` - confirm against IPEX.
            return normed, (hidden_states if residual is None else residual)
|
|
|
|
|
|
class FastRMSNorm(nn.Module):
    """RMS normalisation with a residual add, dispatched on ``SYSTEM``."""

    def __init__(self, weight: torch.Tensor, eps: float):
        """Wrap ``weight`` in a parameter and remember ``eps``."""
        super().__init__()

        self.variance_epsilon = eps
        self.weight = nn.Parameter(weight)

    @classmethod
    def load(cls, prefix, weights, eps=1e-6):
        """Build the module from the tensor ``{prefix}.weight`` in ``weights``."""
        return cls(weights.get_tensor(f"{prefix}.weight"), eps)

    def forward(self, hidden_states, residual=None):
        """Return ``(normed_hidden_states, residual)``.

        Branch order matters: XPU is checked first, then the wide-input
        PyTorch fallback, then the CUDA and ROCm kernels.

        Raises:
            ValueError: if none of the branches applies to ``SYSTEM``.
        """
        if SYSTEM == "xpu":
            normed = ipex.llm.functional.add_rms_norm(
                residual,
                hidden_states,
                self.weight,
                None,
                self.variance_epsilon,
                True,
            )
            # Without an incoming residual, the input is returned as residual.
            return normed, (hidden_states if residual is None else residual)

        if hidden_states.shape[-1] > 8192:
            # Pure PyTorch path for last dimensions above 8192.
            if residual is not None:
                hidden_states += residual  # in place: caller's tensor changes
            summed = hidden_states

            # Statistics are computed in float32.
            upcast = summed.to(torch.float32)
            mean_square = upcast.pow(2).mean(-1, keepdim=True)
            normed = upcast * torch.rsqrt(mean_square + self.variance_epsilon)

            # Go back to half precision when the weight is stored that way.
            if self.weight.dtype in (torch.float16, torch.bfloat16):
                normed = normed.to(self.weight.dtype)

            return self.weight * normed, summed

        if SYSTEM == "cuda":
            # Faster post-attention RMS norm through the fused extension.
            fused_out = dropout_layer_norm.dropout_add_ln_fwd(
                hidden_states,
                residual,
                self.weight,
                None,
                None,
                None,
                None,
                None,
                0.0,
                self.variance_epsilon,
                1.0,
                0,
                None,
                False,
                True,  # Activate RMSNorm
            )
            normed, summed, *_ = fused_out
            if summed is None:
                summed = hidden_states
            return normed, summed

        if SYSTEM == "rocm":
            # The vLLM RMSNorm kernel is used here because it can be compiled
            # for ROCm, unlike the Flash Attention one.
            if residual is not None:
                hidden_states += residual  # in place: caller's tensor changes

            normed = torch.empty_like(hidden_states)
            ops.rms_norm(
                normed,
                hidden_states,
                self.weight.data,
                self.variance_epsilon,
            )
            return normed, hidden_states

        raise ValueError(
            "Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."
        )
|