From 2264702c016f4398618c03f7f76c9692a0fad186 Mon Sep 17 00:00:00 2001
From: Mohit Sharma
Date: Mon, 9 Dec 2024 10:30:03 +0000
Subject: [PATCH] (kernel) add marlin-kernels

---
 Dockerfile_amd                              | 12 ++++++++++++
 server/text_generation_server/layers/fp8.py |  7 +++++--
 2 files changed, 17 insertions(+), 2 deletions(-)

diff --git a/Dockerfile_amd b/Dockerfile_amd
index 7638947a..77d4e613 100644
--- a/Dockerfile_amd
+++ b/Dockerfile_amd
@@ -267,6 +267,15 @@ COPY server/exllamav2_kernels/ .
 
 RUN python setup.py build
 
+FROM kernel-builder AS marlin-kernels
+WORKDIR /usr/src
+ENV MARLIN_KERNELS_BRANCH=v0.3.6
+ENV VLLM_TARGET_DEVICE=rocm
+RUN git clone https://github.com/danieldk/marlin-kernels.git && \
+    cd marlin-kernels && \
+    git checkout ${MARLIN_KERNELS_BRANCH} && \
+    python setup.py install
+
 FROM install_deps AS base-copy
 
 # Text Generation Inference base env
@@ -289,6 +298,9 @@ COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311
 # Copy build artifacts from exllamav2 kernels builder
 COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
 
+# Copy build artifacts from marlin kernels
+COPY --from=marlin-kernels /usr/src/marlin-kernels/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
+
 # Install server
 COPY proto proto
 COPY server server
diff --git a/server/text_generation_server/layers/fp8.py b/server/text_generation_server/layers/fp8.py
index 32b6cdd6..d7fb64ba 100644
--- a/server/text_generation_server/layers/fp8.py
+++ b/server/text_generation_server/layers/fp8.py
@@ -19,6 +19,9 @@ try:
 except ImportError:
     marlin_kernels = None
 
+quant_dtype: torch.dtype = (
+    torch.float8_e4m3fnuz if SYSTEM == "rocm" else torch.float8_e4m3fn
+)
 
 if SYSTEM == "cuda" and marlin_kernels is not None:
     major, minor = torch.cuda.get_device_capability()
@@ -91,7 +94,7 @@ def requantize_with_max_scale(
     weight: torch.Tensor, weight_scale: torch.Tensor, logical_widths: int
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     # Max scale to be used for requanitzation.
-    max_w_scale = weight_scale.max()
+    max_w_scale = weight_scale.max().float()
     start = 0
 
     for idx, logical_width in enumerate(logical_widths):
@@ -109,7 +112,7 @@ def fp8_quantize(
     weight: torch.Tensor,
     scale: Optional[torch.Tensor] = None,
     scale_upper_bound: Optional[torch.Tensor] = None,
-    qdtype: torch.dtype = torch.float8_e4m3fn,
+    qdtype: torch.dtype = quant_dtype,
     scalar: bool = False,
 ):
     """
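
Editor's note (not part of the patch): the fp8.py hunks above switch the default
quantization dtype to torch.float8_e4m3fnuz on ROCm while keeping
torch.float8_e4m3fn on CUDA. Below is a minimal, self-contained sketch of that
dtype selection together with a per-tensor fp8 quantization round-trip; the
SYSTEM detection and the fp8_quantize_sketch helper are simplified assumptions
for illustration, not the exact text-generation-inference code.

    import torch

    # Assumption: approximate TGI's SYSTEM detection ("rocm" vs. "cuda").
    SYSTEM = "rocm" if torch.version.hip is not None else "cuda"

    # ROCm GPUs use the fnuz fp8 variant; CUDA uses the standard e4m3fn.
    quant_dtype: torch.dtype = (
        torch.float8_e4m3fnuz if SYSTEM == "rocm" else torch.float8_e4m3fn
    )

    def fp8_quantize_sketch(weight: torch.Tensor, qdtype: torch.dtype = quant_dtype):
        # Per-tensor scale: map the largest weight magnitude onto the fp8 range.
        finfo = torch.finfo(qdtype)
        scale = finfo.max / weight.abs().max().clamp(min=1e-12)
        qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max).to(qdtype)
        # Return the dequantization scale (reciprocal of the quantization scale),
        # kept in fp32 as the .float() change in the patch also does.
        return qweight, scale.float().reciprocal()

    w = torch.randn(16, 32)
    qw, s = fp8_quantize_sketch(w)
    print(qw.dtype, s.item())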