Adds support for AMD Instinct MI300 in TGI.
The main changes are:
* Support PyTorch TunableOp to pick the best GEMM/GEMV kernels for decoding
(see https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/cuda/tunable).
TunableOp is disabled by default and can be enabled with
`PYTORCH_TUNABLEOP_ENABLED=1`.
* Update the ROCm Dockerfile to PyTorch 2.3 (patched with the changes from
https://github.com/pytorch/pytorch/pull/124362).
* Support the custom SiLU and linear kernels contributed by AMD.
* Update the vLLM paged attention kernels to https://github.com/fxmarty/rocm-vllm/,
branched from a much more recent upstream commit
(3489ce7936).
* Support the Flash Attention 2 Triton kernel, as recommended by AMD. It can
be enabled with `ROCM_USE_FLASH_ATTN_V2_TRITON=1`.
* Update the Dockerfile to ROCm 6.1.
By default, TunableOp tuning results are saved in `/data` (e.g.
`/data/tunableop_meta-llama-Llama-2-70b-chat-hf_tp1_rank0.csv`) so that the
tuning does not have to be rerun at each `docker run`.
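As a minimal standalone sketch of what TunableOp does (outside of TGI; the
shapes and the output filename below are illustrative), tuning can be driven
entirely through environment variables:

```
import os

# TunableOp is configured through environment variables; set them before
# PyTorch initializes the GPU.
os.environ["PYTORCH_TUNABLEOP_ENABLED"] = "1"  # tuning is off by default
os.environ["PYTORCH_TUNABLEOP_FILENAME"] = "tunableop_example.csv"  # illustrative

import torch

x = torch.randn(8192, 8192, dtype=torch.float16, device="cuda")
# Skinny GEMMs like the ones hit during decoding: the first call for each
# shape benchmarks the available kernels, later calls reuse the winner.
for n in range(1, 8):
    w = torch.randn(8192, n, dtype=torch.float16, device="cuda")
    y = x @ w

# The selected kernels and their timings are written to the CSV file on exit.
```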
Example of a resulting tuning file (after the `Validator` header lines, each
entry records the tuned operation, its problem shape, the selected kernel and
the measured time):
```
Validator,PT_VERSION,2.3.0
Validator,ROCM_VERSION,6.1.0.0-82-5fabb4c
Validator,HIPBLASLT_VERSION,0.7.0-1549b021
Validator,GCN_ARCH_NAME,gfx942:sramecc+:xnack-
Validator,ROCBLAS_VERSION,4.1.0-cefa4a9b-dirty
GemmTunableOp_Half_TN,tn_8192_7_28672,Gemm_Rocblas_45475,0.132098
GemmTunableOp_Half_TN,tn_10240_4_8192,Gemm_Rocblas_45546,0.0484431
GemmTunableOp_Half_TN,tn_32000_6_8192,Default,0.149546
GemmTunableOp_Half_TN,tn_32000_3_8192,Gemm_Rocblas_45520,0.147119
GemmTunableOp_Half_TN,tn_8192_3_28672,Gemm_Rocblas_45475,0.132645
GemmTunableOp_Half_TN,tn_10240_3_8192,Gemm_Rocblas_45546,0.0482971
GemmTunableOp_Half_TN,tn_57344_5_8192,Gemm_Rocblas_45520,0.255694
GemmTunableOp_Half_TN,tn_10240_7_8192,Gemm_Rocblas_45517,0.0482522
GemmTunableOp_Half_TN,tn_8192_3_8192,Gemm_Rocblas_45546,0.0444671
GemmTunableOp_Half_TN,tn_8192_5_8192,Gemm_Rocblas_45546,0.0445834
GemmTunableOp_Half_TN,tn_57344_7_8192,Gemm_Rocblas_45520,0.25622
GemmTunableOp_Half_TN,tn_8192_2_28672,Gemm_Rocblas_45475,0.132122
GemmTunableOp_Half_TN,tn_8192_4_8192,Gemm_Rocblas_45517,0.0453191
GemmTunableOp_Half_TN,tn_10240_5_8192,Gemm_Rocblas_45517,0.0482514
GemmTunableOp_Half_TN,tn_8192_5_28672,Gemm_Rocblas_45542,0.133914
GemmTunableOp_Half_TN,tn_8192_2_8192,Gemm_Rocblas_45517,0.0446516
GemmTunableOp_Half_TN,tn_8192_1_28672,Gemm_Hipblaslt_TN_10814,0.131953
GemmTunableOp_Half_TN,tn_10240_2_8192,Gemm_Rocblas_45546,0.0481043
GemmTunableOp_Half_TN,tn_32000_4_8192,Gemm_Rocblas_45520,0.147497
GemmTunableOp_Half_TN,tn_8192_6_28672,Gemm_Rocblas_45529,0.134895
GemmTunableOp_Half_TN,tn_57344_2_8192,Gemm_Rocblas_45520,0.254716
GemmTunableOp_Half_TN,tn_57344_4_8192,Gemm_Rocblas_45520,0.255731
GemmTunableOp_Half_TN,tn_10240_6_8192,Gemm_Rocblas_45517,0.0484816
GemmTunableOp_Half_TN,tn_57344_3_8192,Gemm_Rocblas_45520,0.254701
GemmTunableOp_Half_TN,tn_8192_4_28672,Gemm_Rocblas_45475,0.132159
GemmTunableOp_Half_TN,tn_32000_2_8192,Default,0.147524
GemmTunableOp_Half_TN,tn_32000_5_8192,Default,0.147074
GemmTunableOp_Half_TN,tn_8192_6_8192,Gemm_Rocblas_45546,0.0454045
GemmTunableOp_Half_TN,tn_57344_6_8192,Gemm_Rocblas_45520,0.255582
GemmTunableOp_Half_TN,tn_32000_7_8192,Default,0.146705
GemmTunableOp_Half_TN,tn_8192_7_8192,Gemm_Rocblas_45546,0.0445489
```
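For illustration, a small hypothetical helper (not part of TGI) that
summarizes which kernel was selected for each shape in such a file:

```
import csv

# Hypothetical helper: print the selected kernel and timing for each tuned
# GEMM shape in a TunableOp results file such as the one above.
def summarize(path: str) -> None:
    with open(path, newline="") as f:
        for row in csv.reader(f):
            if not row or row[0] == "Validator":
                continue  # skip blank rows and the Validator header lines
            op, shape, kernel, time = row
            print(f"{op} {shape}: {kernel} ({float(time):.4f} ms)")

summarize("/data/tunableop_meta-llama-Llama-2-70b-chat-hf_tp1_rank0.csv")
```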
---------
Co-authored-by: Mohit Sharma <mohit21sharma.ms@gmail.com>
`hip_compat.cuh` (53 lines, 2.6 KiB), the compatibility header that lets the
hipified exllama kernels build against ROCm:
```
// Adapted from turboderp exllama: https://github.com/turboderp/exllama

#ifndef _hip_compat_cuh
#define _hip_compat_cuh

#include <hip/hip_fp16.h>    // __half / __half2 types used below
#include <hipblas/hipblas.h> // hipblasHgemm and hipblasHalf

// Workaround for a bug in hipamd, backported from upstream, this is fixed in ROCm 5.6.
__device__ __forceinline__ __half __compat_hrcp(__half x) {
    return __half_raw{
        static_cast<_Float16>(__builtin_amdgcn_rcph(static_cast<__half_raw>(x).data))};
}

__device__ __forceinline__ __half2 __compat_h2rcp(__half2 x) {
    return _Float16_2{
        _Float16_2{static_cast<_Float16>(1.0f),
                   static_cast<_Float16>(1.0f)} / x.data};
}

#define hrcp __compat_hrcp
#define h2rcp __compat_h2rcp

// Automatic conversion of hipblasHgemm doesn't convert half to hipblasHalf.
__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t handle,
                                                               hipblasOperation_t transA,
                                                               hipblasOperation_t transB,
                                                               int m,
                                                               int n,
                                                               int k,
                                                               const half* alpha,
                                                               const half* AP,
                                                               int lda,
                                                               const half* BP,
                                                               int ldb,
                                                               const half* beta,
                                                               half* CP,
                                                               int ldc) {
    return hipblasHgemm(handle, transA, transB, m, n, k,
                        reinterpret_cast<const hipblasHalf *>(alpha),
                        reinterpret_cast<const hipblasHalf *>(AP), lda,
                        reinterpret_cast<const hipblasHalf *>(BP), ldb,
                        reinterpret_cast<const hipblasHalf *>(beta),
                        reinterpret_cast<hipblasHalf *>(CP), ldc);
}
#define hipblasHgemm __compat_hipblasHgemm

// Previous versions of PyTorch converted to rocBLAS instead of hipBLAS.
#define rocblas_handle hipblasHandle_t
#define rocblas_operation_none HIPBLAS_OP_N
#define rocblas_get_stream hipblasGetStream
#define rocblas_set_stream hipblasSetStream
#define rocblas_hgemm __compat_hipblasHgemm

#endif
```
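With these remappings in place, the exllama sources compile for ROCm largely
unchanged: calls to `hrcp`, `h2rcp` and `hipblasHgemm`, as well as the
`rocblas_*` symbols, resolve to the compatibility wrappers defined above.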