(kernel) add marlin-kernels

This commit is contained in:
Mohit Sharma 2024-12-09 10:30:03 +00:00
parent e22cb47fe3
commit 2264702c01
2 changed files with 17 additions and 2 deletions

View File

@ -267,6 +267,15 @@ COPY server/exllamav2_kernels/ .
RUN python setup.py build
FROM kernel-builder AS marlin-kernels
WORKDIR /usr/src
ENV MARLIN_KERNELS_BRANCH=v0.3.6
ENV VLLM_TARGET_DEVICE=rocm
RUN git clone https://github.com/danieldk/marlin-kernels.git && \
cd marlin-kernels && \
git checkout ${MARLIN_KERNELS_BRANCH} && \
python setup.py install
FROM install_deps AS base-copy
# Text Generation Inference base env
@ -289,6 +298,9 @@ COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311
# Copy build artifacts from exllamav2 kernels builder
COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
# Copy build artifacts from marlin kernels
COPY --from=marlin-kernels /usr/src/marlin-kernels/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
# Install server
COPY proto proto
COPY server server

View File

@ -19,6 +19,9 @@ try:
except ImportError:
marlin_kernels = None
quant_dtype: torch.dtype = (
torch.float8_e4m3fnuz if SYSTEM == "rocm" else torch.float8_e4m3fn
)
if SYSTEM == "cuda" and marlin_kernels is not None:
major, minor = torch.cuda.get_device_capability()
@ -91,7 +94,7 @@ def requantize_with_max_scale(
weight: torch.Tensor, weight_scale: torch.Tensor, logical_widths: int
) -> Tuple[torch.Tensor, torch.Tensor]:
# Max scale to be used for requanitzation.
max_w_scale = weight_scale.max()
max_w_scale = weight_scale.max().float()
start = 0
for idx, logical_width in enumerate(logical_widths):
@ -109,7 +112,7 @@ def fp8_quantize(
weight: torch.Tensor,
scale: Optional[torch.Tensor] = None,
scale_upper_bound: Optional[torch.Tensor] = None,
qdtype: torch.dtype = torch.float8_e4m3fn,
qdtype: torch.dtype = quant_dtype,
scalar: bool = False,
):
"""