From 1b4c8b4b3e78bb855316c1c887420e7613ecec80 Mon Sep 17 00:00:00 2001
From: fxmarty <9808326+fxmarty@users.noreply.github.com>
Date: Fri, 19 Apr 2024 11:50:01 +0000
Subject: [PATCH] _custom_C.LLMM1 and HIP_FORCE_DEV_KERNARG=1

---
 Dockerfile_amd                                |  3 +-
 server/text_generation_server/utils/layers.py | 68 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-
 2 files changed, 69 insertions(+), 2 deletions(-)

diff --git a/Dockerfile_amd b/Dockerfile_amd
index c532bae9..609dddb1 100644
--- a/Dockerfile_amd
+++ b/Dockerfile_amd
@@ -141,7 +141,8 @@ FROM base as base-copy
 # Text Generation Inference base env
 ENV HUGGINGFACE_HUB_CACHE=/data \
     HF_HUB_ENABLE_HF_TRANSFER=1 \
-    PORT=80
+    PORT=80 \
+    HIP_FORCE_DEV_KERNARG=1
 
 # Copy builds artifacts from triton builder
 COPY --from=triton-builder /usr/src/triton/python/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index 44d593e1..2f1a2b64 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -67,6 +67,12 @@ try:
 except ImportError:
     pass
 
+if IS_ROCM_SYSTEM:
+    try:
+        # Custom skinny-GEMM kernels shipped with the vLLM ROCm fork.
+        from vllm import _custom_C
+    except Exception as e:
+        raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")
 
 # Monkey patching
 @classmethod
@@ -324,10 +330,70 @@ def warn_deprecate_bnb():
     logger.warning(
         "Bitsandbytes 8bit is deprecated, using `eetq` is a drop-in replacement, and has much better performnce"
     )
 
+
+class FastLinearROCm(nn.Module):
+    def __init__(
+        self,
+        weight,
+        bias,
+    ) -> None:
+        super().__init__()
+        self.weight = nn.Parameter(weight)
+        if bias is not None:
+            self.bias = nn.Parameter(bias)
+        else:
+            self.bias = None
+
+    @classmethod
+    def load(cls, config, prefix: str, weights, bias: bool):
+        weight = weights.get_tensor(f"{prefix}.weight")
+        if bias:
+            bias = weights.get_tensor(f"{prefix}.bias")
+        else:
+            bias = None
+        return cls(weight, bias)
+
+    def forward(self, inp: torch.Tensor) -> torch.Tensor:
+        weight = self.weight
+        bias = self.bias
+
+        # The custom kernels only pay off for skinny GEMMs, i.e. when the
+        # input reduces to a single row (decoding with an effective batch of 1).
+        if IS_ROCM_SYSTEM and inp.numel() // inp.size(-1) == 1:
+            batched = False
+            inp_shape = inp.shape
+
+            if inp.dim() == 3:
+                inp = inp.view(-1, inp_shape[-1])
+                batched = True
+
+            m, k = weight.shape[0], inp.shape[1]
+            out = torch.empty(
+                inp.shape[0], weight.shape[0], dtype=inp.dtype, device="cuda"
+            )
+            if (k == 8192 and (m == 1280 or m == 7168)) or (k == 3584 and m == 8192):
+                _custom_C.LLMM1(weight, inp, out, 8)
+            elif k <= 8192 and k % 8 == 0 and m % 4 == 0:
+                _custom_C.LLMM1(weight, inp, out, 4)
+            else:
+                out = F.linear(inp, weight)
+
+            if batched:
+                out = out.view(*inp_shape[:-1], out.shape[-1])
+
+            if bias is not None:
+                out = out + bias
+            return out
+        return F.linear(inp, self.weight, self.bias)
+
+
 def get_linear(weight, bias, quantize):
     if quantize is None:
-        linear = FastLinear(weight, bias)
+        if IS_ROCM_SYSTEM:
+            linear = FastLinearROCm(weight, bias)
+        else:
+            linear = FastLinear(weight, bias)
     elif quantize == "eetq":
         if HAS_EETQ:
             linear = EETQLinear(weight, bias)