mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
(kernel) add marlin-kernels
This commit is contained in:
parent e22cb47fe3
commit 2264702c01
@@ -267,6 +267,15 @@ COPY server/exllamav2_kernels/ .

 RUN python setup.py build

+FROM kernel-builder AS marlin-kernels
+WORKDIR /usr/src
+ENV MARLIN_KERNELS_BRANCH=v0.3.6
+ENV VLLM_TARGET_DEVICE=rocm
+RUN git clone https://github.com/danieldk/marlin-kernels.git && \
+    cd marlin-kernels && \
+    git checkout ${MARLIN_KERNELS_BRANCH} && \
+    python setup.py install
+
 FROM install_deps AS base-copy

 # Text Generation Inference base env
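The new builder stage pins marlin-kernels to the v0.3.6 tag and exports VLLM_TARGET_DEVICE=rocm so the build emits HIP binaries rather than CUDA ones. A minimal sketch of that env-driven dispatch, assuming only what the Dockerfile shows (the helper itself is hypothetical, not marlin-kernels' actual setup code):

import os

def select_build_target() -> str:
    # Hypothetical helper: mirrors the dispatch implied by
    # ENV VLLM_TARGET_DEVICE=rocm in the builder stage above.
    target = os.environ.get("VLLM_TARGET_DEVICE", "cuda")
    if target not in {"cuda", "rocm"}:
        raise ValueError(f"unsupported VLLM_TARGET_DEVICE: {target}")
    return target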
@@ -289,6 +298,9 @@ COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311
 # Copy build artifacts from exllamav2 kernels builder
 COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages

+# Copy build artifacts from marlin kernels
+COPY --from=marlin-kernels /usr/src/marlin-kernels/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
+
 # Install server
 COPY proto proto
 COPY server server
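The COPY pulls the build tree produced by setup.py straight into the final image's site-packages. A quick sanity check that the artifact is importable at runtime (illustrative, not part of the commit):

import importlib.util

# The package name matches the import guarded in the fp8 layer below.
spec = importlib.util.find_spec("marlin_kernels")
print("marlin_kernels importable:", spec is not None)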
@@ -19,6 +19,9 @@ try:
 except ImportError:
     marlin_kernels = None

+quant_dtype: torch.dtype = (
+    torch.float8_e4m3fnuz if SYSTEM == "rocm" else torch.float8_e4m3fn
+)

 if SYSTEM == "cuda" and marlin_kernels is not None:
     major, minor = torch.cuda.get_device_capability()
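The module-level quant_dtype makes the fp8 flavor platform-dependent: ROCm devices use float8_e4m3fnuz, whose dynamic range is narrower than CUDA's float8_e4m3fn, so downstream scale computations must target the matching dtype. A quick illustration of the difference (values come from PyTorch's finfo):

import torch

for dt in (torch.float8_e4m3fn, torch.float8_e4m3fnuz):
    fi = torch.finfo(dt)
    print(dt, "max:", fi.max)  # 448.0 for e4m3fn, 240.0 for e4m3fnuz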
@@ -91,7 +94,7 @@ def requantize_with_max_scale(
     weight: torch.Tensor, weight_scale: torch.Tensor, logical_widths: int
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     # Max scale to be used for requantization.
-    max_w_scale = weight_scale.max()
+    max_w_scale = weight_scale.max().float()

     start = 0
     for idx, logical_width in enumerate(logical_widths):
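The one-line change casts the shared max scale to float32 before it is used, so the per-shard dequantize/requantize arithmetic does not run at reduced precision. A minimal sketch of the shared-max-scale loop, assuming fp8 weights and per-shard scales (illustrative, not TGI's exact implementation):

from typing import List, Tuple

import torch

def requantize_with_max_scale_sketch(
    weight: torch.Tensor, weight_scale: torch.Tensor, logical_widths: List[int]
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Shared scale across all logical shards, cast to float32 as in the diff.
    max_w_scale = weight_scale.max().float()
    out = torch.empty_like(weight)
    start = 0
    for idx, logical_width in enumerate(logical_widths):
        end = start + logical_width
        # Dequantize the shard with its own scale, requantize with the max.
        deq = weight[start:end].to(torch.float32) * weight_scale[idx]
        out[start:end] = (deq / max_w_scale).to(weight.dtype)
        start = end
    return out, max_w_scale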
@@ -109,7 +112,7 @@ def fp8_quantize(
     weight: torch.Tensor,
     scale: Optional[torch.Tensor] = None,
     scale_upper_bound: Optional[torch.Tensor] = None,
-    qdtype: torch.dtype = torch.float8_e4m3fn,
+    qdtype: torch.dtype = quant_dtype,
     scalar: bool = False,
 ):
     """
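With quant_dtype as the default, callers of fp8_quantize that never pass qdtype now transparently get the ROCm variant on AMD hardware. For reference, a minimal sketch of symmetric per-tensor fp8 quantization with a dtype-aware range (illustrative; the real fp8_quantize also handles preset scales, scale_upper_bound, and scalar outputs):

import torch

def fp8_quantize_sketch(weight: torch.Tensor, qdtype: torch.dtype):
    finfo = torch.finfo(qdtype)
    # Map the tensor's max magnitude onto the fp8 range of the chosen dtype.
    scale = finfo.max / weight.abs().max().clamp(min=1e-12)
    qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max).to(qdtype)
    # Return the dequantization scale (reciprocal of the quant scale).
    return qweight, scale.float().reciprocal()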