(kernel) add marlin-kernels

Mohit Sharma 2024-12-09 10:30:03 +00:00
parent e22cb47fe3
commit 2264702c01
2 changed files with 17 additions and 2 deletions

View File

@@ -267,6 +267,15 @@ COPY server/exllamav2_kernels/ .
 RUN python setup.py build
+FROM kernel-builder AS marlin-kernels
+WORKDIR /usr/src
+ENV MARLIN_KERNELS_BRANCH=v0.3.6
+ENV VLLM_TARGET_DEVICE=rocm
+RUN git clone https://github.com/danieldk/marlin-kernels.git && \
+    cd marlin-kernels && \
+    git checkout ${MARLIN_KERNELS_BRANCH} && \
+    python setup.py install
 FROM install_deps AS base-copy
 # Text Generation Inference base env
@@ -289,6 +298,9 @@ COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311
 # Copy build artifacts from exllamav2 kernels builder
 COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
+# Copy build artifacts from marlin kernels
+COPY --from=marlin-kernels /usr/src/marlin-kernels/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
 # Install server
 COPY proto proto
 COPY server server
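
The new marlin-kernels stage builds danieldk/marlin-kernels at the pinned v0.3.6 tag for ROCm (VLLM_TARGET_DEVICE=rocm), and the second hunk copies the built extension into the runtime image's site-packages. As an illustrative check only (not part of this commit), the runtime can confirm the copied extension is importable using the same guard pattern the fp8 module below relies on:

# Illustrative sketch: probe for the marlin_kernels extension copied into
# site-packages; fall back to None exactly as the fp8 module below does.
try:
    import marlin_kernels
except ImportError:
    marlin_kernels = None

if marlin_kernels is None:
    print("marlin_kernels unavailable; fp8/marlin code paths will be disabled")
else:
    print("marlin_kernels loaded from", marlin_kernels.__file__)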

View File

@@ -19,6 +19,9 @@ try:
 except ImportError:
     marlin_kernels = None
+quant_dtype: torch.dtype = (
+    torch.float8_e4m3fnuz if SYSTEM == "rocm" else torch.float8_e4m3fn
+)
 if SYSTEM == "cuda" and marlin_kernels is not None:
     major, minor = torch.cuda.get_device_capability()
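
The module-level quant_dtype picks the fp8 format per platform: float8_e4m3fnuz on ROCm, float8_e4m3fn elsewhere. A minimal sketch of that selection in isolation; here SYSTEM is derived from torch.version.hip as a stand-in for the server's own device detection (an assumption for illustration):

import torch

# Stand-in for the server's SYSTEM detection: torch.version.hip is non-None
# only on ROCm builds of PyTorch.
SYSTEM = "rocm" if torch.version.hip is not None else "cuda"

quant_dtype: torch.dtype = (
    torch.float8_e4m3fnuz if SYSTEM == "rocm" else torch.float8_e4m3fn
)
print(quant_dtype)  # float8_e4m3fnuz on ROCm, float8_e4m3fn on CUDA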
@@ -91,7 +94,7 @@ def requantize_with_max_scale(
     weight: torch.Tensor, weight_scale: torch.Tensor, logical_widths: int
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     # Max scale to be used for requantization.
-    max_w_scale = weight_scale.max()
+    max_w_scale = weight_scale.max().float()
     start = 0
     for idx, logical_width in enumerate(logical_widths):
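
Taking .float() on the shared maximum scale keeps the requantization arithmetic in float32 rather than the scale tensor's original dtype. A simplified sketch of the dequantize-then-requantize loop under that change; shapes and the exact clamping/rounding behaviour of the project's implementation are not reproduced here:

import torch

def requantize_with_max_scale_sketch(weight, weight_scale, logical_widths, qdtype):
    # Single shared scale, promoted to float32 as in the change above.
    max_w_scale = weight_scale.max().float()
    start = 0
    for idx, logical_width in enumerate(logical_widths):
        end = start + logical_width
        # Dequantize each logical shard with its own scale, then requantize
        # against the shared maximum scale.
        dq = weight[start:end].to(torch.float32) * weight_scale[idx]
        weight[start:end] = (dq / max_w_scale).to(qdtype)
        start = end
    return weight, max_w_scale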
@@ -109,7 +112,7 @@ def fp8_quantize(
     weight: torch.Tensor,
     scale: Optional[torch.Tensor] = None,
     scale_upper_bound: Optional[torch.Tensor] = None,
-    qdtype: torch.dtype = torch.float8_e4m3fn,
+    qdtype: torch.dtype = quant_dtype,
     scalar: bool = False,
 ):
     """