mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-21 14:52:20 +00:00
Adds support for AMD Instinct MI300 in TGI.
Most changes are:
* Support PyTorch TunableOp to pick the GEMM/GEMV kernels for decoding
https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/cuda/tunable.
TunableOp is disabled by default, and can be enabled with
`PYTORCH_TUNABLEOP_ENABLED=1`.
* Update ROCm dockerfile to PyTorch 2.3 (actually patched with changes
from https://github.com/pytorch/pytorch/pull/124362)
* Support SILU & Linear custom kernels contributed by AMD
* Update vLLM paged attention to https://github.com/fxmarty/rocm-vllm/,
branching out of a much more recent commit
3489ce7936
* Support FA2 Triton kernel as recommended by AMD. Can be used by
specifying `ROCM_USE_FLASH_ATTN_V2_TRITON=1`.
* Update dockerfile to ROCm 6.1
By default, TunableOp tuning results are saved in `/data` (e.g.
`/data/tunableop_meta-llama-Llama-2-70b-chat-hf_tp1_rank0.csv`) to avoid
having to rerun the tuning at each `docker run`.
Example:
```
Validator,PT_VERSION,2.3.0
Validator,ROCM_VERSION,6.1.0.0-82-5fabb4c
Validator,HIPBLASLT_VERSION,0.7.0-1549b021
Validator,GCN_ARCH_NAME,gfx942:sramecc+:xnack-
Validator,ROCBLAS_VERSION,4.1.0-cefa4a9b-dirty
GemmTunableOp_Half_TN,tn_8192_7_28672,Gemm_Rocblas_45475,0.132098
GemmTunableOp_Half_TN,tn_10240_4_8192,Gemm_Rocblas_45546,0.0484431
GemmTunableOp_Half_TN,tn_32000_6_8192,Default,0.149546
GemmTunableOp_Half_TN,tn_32000_3_8192,Gemm_Rocblas_45520,0.147119
GemmTunableOp_Half_TN,tn_8192_3_28672,Gemm_Rocblas_45475,0.132645
GemmTunableOp_Half_TN,tn_10240_3_8192,Gemm_Rocblas_45546,0.0482971
GemmTunableOp_Half_TN,tn_57344_5_8192,Gemm_Rocblas_45520,0.255694
GemmTunableOp_Half_TN,tn_10240_7_8192,Gemm_Rocblas_45517,0.0482522
GemmTunableOp_Half_TN,tn_8192_3_8192,Gemm_Rocblas_45546,0.0444671
GemmTunableOp_Half_TN,tn_8192_5_8192,Gemm_Rocblas_45546,0.0445834
GemmTunableOp_Half_TN,tn_57344_7_8192,Gemm_Rocblas_45520,0.25622
GemmTunableOp_Half_TN,tn_8192_2_28672,Gemm_Rocblas_45475,0.132122
GemmTunableOp_Half_TN,tn_8192_4_8192,Gemm_Rocblas_45517,0.0453191
GemmTunableOp_Half_TN,tn_10240_5_8192,Gemm_Rocblas_45517,0.0482514
GemmTunableOp_Half_TN,tn_8192_5_28672,Gemm_Rocblas_45542,0.133914
GemmTunableOp_Half_TN,tn_8192_2_8192,Gemm_Rocblas_45517,0.0446516
GemmTunableOp_Half_TN,tn_8192_1_28672,Gemm_Hipblaslt_TN_10814,0.131953
GemmTunableOp_Half_TN,tn_10240_2_8192,Gemm_Rocblas_45546,0.0481043
GemmTunableOp_Half_TN,tn_32000_4_8192,Gemm_Rocblas_45520,0.147497
GemmTunableOp_Half_TN,tn_8192_6_28672,Gemm_Rocblas_45529,0.134895
GemmTunableOp_Half_TN,tn_57344_2_8192,Gemm_Rocblas_45520,0.254716
GemmTunableOp_Half_TN,tn_57344_4_8192,Gemm_Rocblas_45520,0.255731
GemmTunableOp_Half_TN,tn_10240_6_8192,Gemm_Rocblas_45517,0.0484816
GemmTunableOp_Half_TN,tn_57344_3_8192,Gemm_Rocblas_45520,0.254701
GemmTunableOp_Half_TN,tn_8192_4_28672,Gemm_Rocblas_45475,0.132159
GemmTunableOp_Half_TN,tn_32000_2_8192,Default,0.147524
GemmTunableOp_Half_TN,tn_32000_5_8192,Default,0.147074
GemmTunableOp_Half_TN,tn_8192_6_8192,Gemm_Rocblas_45546,0.0454045
GemmTunableOp_Half_TN,tn_57344_6_8192,Gemm_Rocblas_45520,0.255582
GemmTunableOp_Half_TN,tn_32000_7_8192,Default,0.146705
GemmTunableOp_Half_TN,tn_8192_7_8192,Gemm_Rocblas_45546,0.0445489
```
---------
Co-authored-by: Mohit Sharma <mohit21sharma.ms@gmail.com>
217 lines
7.0 KiB
Python
217 lines
7.0 KiB
Python
import torch
|
|
from torch.nn import functional as F
|
|
from text_generation_server.utils.import_utils import SYSTEM
|
|
|
|
if SYSTEM == "rocm":
    # The custom ROCm kernels are only imported on ROCm systems. Any failure
    # while loading them is surfaced as an ImportError; the original exception
    # is chained so its traceback is not lost.
    try:
        from vllm import _custom_C
    except Exception as e:
        raise ImportError(
            f"Could not load `vllm._custom_C`. Full error: {e}"
        ) from e
|
|
|
|
|
|
class FastLinear(torch.nn.Module):
    """Linear layer backed by ``torch.nn.functional.linear``.

    The weight (and the bias, when given) are wrapped in parameters created
    with ``requires_grad=False``.
    """

    def __init__(
        self,
        weight,
        bias,
    ) -> None:
        """Store ``weight`` and the optional ``bias`` as frozen parameters."""
        super().__init__()
        self.weight = torch.nn.Parameter(weight, requires_grad=False)
        # Without a bias tensor the attribute is kept as a plain ``None``.
        self.bias = (
            None if bias is None else torch.nn.Parameter(bias, requires_grad=False)
        )

    @classmethod
    def load(cls, config, prefix: str, weights, bias: bool):
        """Build the layer from the tensors stored under ``prefix``.

        ``weights`` must expose ``get_tensor(name)``. ``"{prefix}.weight"`` is
        always fetched; ``"{prefix}.bias"`` is fetched only when ``bias`` is
        truthy. ``config`` is accepted but not used here.
        """
        weight_tensor = weights.get_tensor(f"{prefix}.weight")
        bias_tensor = weights.get_tensor(f"{prefix}.bias") if bias else None
        return cls(weight_tensor, bias_tensor)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply the linear transformation to ``input``."""
        projected = F.linear(input, self.weight, self.bias)
        return projected
|
|
|
|
|
|
class FastLinearROCm(torch.nn.Module):
    """Linear layer that uses a custom ROCm kernel for single-row inputs.

    When ``SYSTEM == "rocm"`` and the input holds exactly one row (all leading
    dimensions have size 1), the matmul is dispatched to ``_custom_C.LLMM1``
    for the supported shapes. Every other case goes through
    ``torch.nn.functional.linear``.
    """

    def __init__(
        self,
        weight,
        bias,
    ) -> None:
        """Store ``weight`` and the optional ``bias`` as frozen parameters."""
        super().__init__()
        # Inference-only layer: keep gradients off, consistent with FastLinear.
        self.weight = torch.nn.Parameter(weight, requires_grad=False)
        if bias is not None:
            self.bias = torch.nn.Parameter(bias, requires_grad=False)
        else:
            self.bias = None

    @classmethod
    def load(cls, config, prefix: str, weights, bias: bool):
        """Build the layer from the tensors stored under ``prefix``.

        ``weights`` must expose ``get_tensor(name)``. ``"{prefix}.weight"`` is
        always fetched; ``"{prefix}.bias"`` is fetched only when ``bias`` is
        truthy. ``config`` is accepted but not used here.
        """
        weight = weights.get_tensor(f"{prefix}.weight")
        if bias:
            bias = weights.get_tensor(f"{prefix}.bias")
        else:
            bias = None
        return cls(weight, bias)

    def forward(self, inp: torch.Tensor) -> torch.Tensor:
        """Apply the linear transformation to ``inp``.

        The output always has the shape ``(*inp.shape[:-1], weight.shape[0])``,
        whichever code path is taken.
        """
        weight = self.weight
        bias = self.bias

        # Only single-row inputs on ROCm are candidates for the custom kernel.
        if SYSTEM != "rocm" or inp.numel() // inp.shape[-1] != 1:
            return F.linear(inp, weight, bias)

        inp_shape = inp.shape
        # k is the reduction (feature) dimension, m the number of outputs.
        m, k = weight.shape[0], inp_shape[-1]

        # The kernel path works on a 2-D (rows, k) input; flatten any other
        # rank and restore the leading dimensions on the result below.
        needs_reshape = inp.dim() != 2
        inp_2d = inp.reshape(-1, k) if needs_reshape else inp

        # NOTE(review): the last LLMM1 argument (8 or 4) is a kernel tuning
        # value whose meaning is not visible here - confirm against the
        # kernel source before changing these shape conditions.
        if (k == 8192 and (m == 1280 or m == 7168)) or (k == 3584 and m == 8192):
            out = torch.empty(inp_2d.shape[0], m, dtype=inp.dtype, device=inp.device)
            _custom_C.LLMM1(weight, inp_2d, out, 8)
        elif k <= 8192 and k % 8 == 0 and m % 4 == 0:
            out = torch.empty(inp_2d.shape[0], m, dtype=inp.dtype, device=inp.device)
            _custom_C.LLMM1(weight, inp_2d, out, 4)
        else:
            # Unsupported shape for the custom kernel: regular matmul.
            out = F.linear(inp_2d, weight)

        if needs_reshape:
            # ``view`` returns a new tensor; the result must be re-assigned.
            out = out.view(*inp_shape[:-1], out.shape[-1])

        if bias is not None:
            out = out + bias
        return out
|
|
|
|
|
|
def get_linear(weight, bias, quantize):
    """Build the linear layer matching the requested quantization scheme.

    Args:
        weight: The layer weight. For ``"gptq"`` and ``"awq"`` this must be a
            7-element sequence that unpacks as ``(qweight, qzeros, scales,
            g_idx, bits, groupsize, use_exllama)``; for the other schemes it
            is passed to the layer constructor as is.
        bias: Optional bias tensor, or ``None``.
        quantize: ``None`` for an unquantized layer, or one of ``"eetq"``,
            ``"fp8"``, ``"bitsandbytes"``, ``"bitsandbytes-fp4"``,
            ``"bitsandbytes-nf4"``, ``"gptq"``, ``"awq"``.

    Returns:
        The constructed linear layer.

    Raises:
        ImportError: If EETQ is requested but cannot be imported.
        NotImplementedError: If the scheme is unknown, its backend is not
            installed, the weight does not have the expected layout, or AWQ
            is requested on a ROCm system.
    """
    if quantize is None:
        if SYSTEM == "rocm":
            linear = FastLinearROCm(weight, bias)
        else:
            linear = FastLinear(weight, bias)
    elif quantize == "eetq":
        try:
            from text_generation_server.layers.eetq import EETQLinear

            linear = EETQLinear(weight, bias)
        except ImportError as e:
            raise ImportError(
                "Please install EETQ from https://github.com/NetEase-FuXi/EETQ"
            ) from e
    elif quantize == "fp8":
        from text_generation_server.layers.fp8 import Fp8Linear

        linear = Fp8Linear(weight, bias)
    elif quantize == "bitsandbytes":
        try:
            from text_generation_server.layers.bnb import (
                warn_deprecate_bnb,
                Linear8bitLt,
            )
        except ImportError as e:
            raise NotImplementedError(
                "Bitsandbytes is missing install it with `pip install bitsandbytes`."
            ) from e
        warn_deprecate_bnb()
        linear = Linear8bitLt(
            weight,
            bias,
            has_fp16_weights=False,
            threshold=6.0,
        )
        if bias is not None:
            # Only ``torch`` is imported in this module, so the parameter class
            # has to be reached through ``torch.nn`` (a bare ``nn`` is undefined).
            linear.bias = torch.nn.Parameter(bias)
    elif quantize == "bitsandbytes-fp4":
        try:
            from text_generation_server.layers.bnb import Linear4bit
        except ImportError as e:
            raise NotImplementedError(
                "Bitsandbytes is missing install it with `pip install bitsandbytes`."
            ) from e
        linear = Linear4bit(
            weight,
            bias,
            quant_type="fp4",
        )
    elif quantize == "bitsandbytes-nf4":
        try:
            from text_generation_server.layers.bnb import Linear4bit
        except ImportError as e:
            raise NotImplementedError(
                "Bitsandbytes is missing install it with `pip install bitsandbytes`."
            ) from e
        linear = Linear4bit(
            weight,
            bias,
            quant_type="nf4",
        )
    elif quantize == "gptq":
        try:
            qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama = weight
        except Exception as e:
            raise NotImplementedError(
                "The passed weight is not `gptq` compatible, loader needs to be updated."
            ) from e

        if use_exllama:
            try:
                from text_generation_server.layers.gptq import (
                    ExllamaQuantLinear,
                )
            except ImportError as e:
                raise NotImplementedError(
                    "Exllama gptq kernels are not installed. Install them `cd server/exllama_kernels && python setup.py install && cd ../exllamav2_kernels && python setup.py install`"
                ) from e

            linear = ExllamaQuantLinear(
                qweight, qzeros, scales, g_idx, bias, bits, groupsize
            )
        else:
            from text_generation_server.layers.gptq.quant_linear import QuantLinear

            linear = QuantLinear(
                qweight,
                qzeros,
                scales,
                g_idx,
                bias,
                bits,
                groupsize,
            )
    elif quantize == "awq":
        try:
            # g_idx and use_exllama are not needed by the AWQ layer.
            qweight, qzeros, scales, _, bits, groupsize, _ = weight
        except Exception as e:
            raise NotImplementedError(
                "The passed weight is not `awq` compatible, loader needs to be updated."
            ) from e
        if SYSTEM == "rocm":
            raise NotImplementedError(
                "AWQ GEMM kernel can't be used on ROCm systems, please use `--quantize gptq` instead "
                "to use Exllama/GPTQ kernels for AWQ inference."
            )
        try:
            from text_generation_server.layers.awq.quantize.qmodule import WQLinear

            linear = WQLinear(
                w_bit=bits,
                group_size=groupsize,
                qweight=qweight,
                qzeros=qzeros,
                scales=scales,
                bias=bias is not None,
            )
        except ImportError as e:
            raise NotImplementedError(
                "You do not seem to have awq installed, either install it (cd server && make install-awq), or try using GPTQ `--quantize gptq` a conversion AWQ->GPTQ will happen on the fly"
            ) from e
    else:
        raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.")
    return linear
|