diff --git a/Dockerfile_amd b/Dockerfile_amd
index 7638947a..77d4e613 100644
--- a/Dockerfile_amd
+++ b/Dockerfile_amd
@@ -267,6 +267,15 @@ COPY server/exllamav2_kernels/ .
 
 RUN python setup.py build
 
+FROM kernel-builder AS marlin-kernels
+WORKDIR /usr/src
+ENV MARLIN_KERNELS_BRANCH=v0.3.6
+ENV VLLM_TARGET_DEVICE=rocm
+RUN git clone https://github.com/danieldk/marlin-kernels.git && \
+    cd marlin-kernels && \
+    git checkout ${MARLIN_KERNELS_BRANCH} && \
+    python setup.py install
+
 FROM install_deps AS base-copy
 
 # Text Generation Inference base env
@@ -289,6 +298,9 @@ COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311
 # Copy build artifacts from exllamav2 kernels builder
 COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
 
+# Copy build artifacts from marlin kernels
+COPY --from=marlin-kernels /usr/src/marlin-kernels/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
+
 # Install server
 COPY proto proto
 COPY server server
diff --git a/server/text_generation_server/layers/fp8.py b/server/text_generation_server/layers/fp8.py
index 32b6cdd6..d7fb64ba 100644
--- a/server/text_generation_server/layers/fp8.py
+++ b/server/text_generation_server/layers/fp8.py
@@ -19,6 +19,9 @@ try:
 except ImportError:
     marlin_kernels = None
 
+quant_dtype: torch.dtype = (
+    torch.float8_e4m3fnuz if SYSTEM == "rocm" else torch.float8_e4m3fn
+)
 
 if SYSTEM == "cuda" and marlin_kernels is not None:
     major, minor = torch.cuda.get_device_capability()
@@ -91,7 +94,7 @@ def requantize_with_max_scale(
     weight: torch.Tensor, weight_scale: torch.Tensor, logical_widths: int
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     # Max scale to be used for requanitzation.
-    max_w_scale = weight_scale.max()
+    max_w_scale = weight_scale.max().float()
 
     start = 0
     for idx, logical_width in enumerate(logical_widths):
@@ -109,7 +112,7 @@ def fp8_quantize(
     weight: torch.Tensor,
     scale: Optional[torch.Tensor] = None,
     scale_upper_bound: Optional[torch.Tensor] = None,
-    qdtype: torch.dtype = torch.float8_e4m3fn,
+    qdtype: torch.dtype = quant_dtype,
     scalar: bool = False,
 ):
     """
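
Note: as a usage sketch only, the snippet below illustrates the dtype switch this change introduces: on ROCm the fp8 default becomes torch.float8_e4m3fnuz (max magnitude 240) instead of torch.float8_e4m3fn (max magnitude 448), so the quantization scale must follow the dtype. It uses torch.version.hip as a stand-in for the SYSTEM check in text_generation_server and a simplified per-tensor scaling; it is not the repository's exact fp8_quantize implementation, and quantize_per_tensor_sketch is a hypothetical name.

import torch

# Stand-in for SYSTEM == "rocm" (assumption: a HIP build of PyTorch implies ROCm).
IS_ROCM = torch.version.hip is not None
quant_dtype = torch.float8_e4m3fnuz if IS_ROCM else torch.float8_e4m3fn

def quantize_per_tensor_sketch(weight: torch.Tensor, qdtype: torch.dtype = quant_dtype):
    # Scale so the largest magnitude maps to the dtype's largest representable value.
    finfo = torch.finfo(qdtype)
    scale = weight.abs().max().clamp(min=1e-12).float() / finfo.max
    qweight = (weight / scale).clamp(min=finfo.min, max=finfo.max).to(qdtype)
    # Keep the scale in float32, mirroring the .float() fix in requantize_with_max_scale.
    return qweight, scale

qweight, scale = quantize_per_tensor_sketch(torch.randn(4096, 4096, dtype=torch.float16))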