Mirror of https://github.com/huggingface/text-generation-inference.git
(kernel) add marlin-kernels
This commit is contained in:
parent e22cb47fe3
commit 2264702c01
@@ -267,6 +267,15 @@ COPY server/exllamav2_kernels/ .
 
 RUN python setup.py build
 
+FROM kernel-builder AS marlin-kernels
+WORKDIR /usr/src
+ENV MARLIN_KERNELS_BRANCH=v0.3.6
+ENV VLLM_TARGET_DEVICE=rocm
+RUN git clone https://github.com/danieldk/marlin-kernels.git && \
+    cd marlin-kernels && \
+    git checkout ${MARLIN_KERNELS_BRANCH} && \
+    python setup.py install
+
 FROM install_deps AS base-copy
 
 # Text Generation Inference base env
@@ -289,6 +298,9 @@ COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311
 # Copy build artifacts from exllamav2 kernels builder
 COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
 
+# Copy build artifacts from marlin kernels
+COPY --from=marlin-kernels /usr/src/marlin-kernels/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
+
 # Install server
 COPY proto proto
 COPY server server
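As a quick sanity check that the artifacts copied above are importable in the final image, something along these lines could be run inside the container. The module name marlin_kernels is taken from the Python changes below; the check itself is an illustration, not part of this commit.

# Hypothetical smoke test (not part of this commit): confirm that the
# marlin_kernels extension copied into site-packages can be imported.
import importlib.util

spec = importlib.util.find_spec("marlin_kernels")
print("marlin_kernels available:", spec is not None)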
@@ -19,6 +19,9 @@ try:
 except ImportError:
     marlin_kernels = None
 
+quant_dtype: torch.dtype = (
+    torch.float8_e4m3fnuz if SYSTEM == "rocm" else torch.float8_e4m3fn
+)
 
 if SYSTEM == "cuda" and marlin_kernels is not None:
     major, minor = torch.cuda.get_device_capability()
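The new quant_dtype picks the FP8 format per platform: ROCm uses float8_e4m3fnuz (no negative zero and a different exponent bias), while CUDA keeps the standard float8_e4m3fn. A standalone sketch of the same selection, using torch.version.hip as a stand-in for TGI's SYSTEM value:

# Standalone sketch of the platform-dependent FP8 dtype choice added above.
# The real module gets SYSTEM from text-generation-inference's utilities; here
# torch.version.hip (non-None on ROCm builds) stands in for that check.
import torch

SYSTEM = "rocm" if torch.version.hip is not None else "cuda"

quant_dtype: torch.dtype = (
    torch.float8_e4m3fnuz if SYSTEM == "rocm" else torch.float8_e4m3fn
)

if __name__ == "__main__":
    print(f"quantizing to {quant_dtype} on {SYSTEM}")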
@@ -91,7 +94,7 @@ def requantize_with_max_scale(
     weight: torch.Tensor, weight_scale: torch.Tensor, logical_widths: int
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     # Max scale to be used for requantization.
-    max_w_scale = weight_scale.max()
+    max_w_scale = weight_scale.max().float()
 
     start = 0
     for idx, logical_width in enumerate(logical_widths):
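The only change here is forcing the shared maximum scale to fp32 before it is reused across shards. A hedged sketch of what max-scale requantization looks like; shapes and details are assumptions for illustration, not the TGI implementation:

from typing import List, Tuple

import torch


def requantize_with_max_scale_sketch(
    weight: torch.Tensor,        # FP8 weight, logical shards stacked along dim 0
    weight_scale: torch.Tensor,  # one dequantization scale per shard
    logical_widths: List[int],   # number of rows belonging to each shard
    qdtype: torch.dtype = torch.float8_e4m3fn,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # .float() (the change above) keeps the shared scale in fp32 so the
    # per-shard rescaling below is not done in a reduced-precision dtype.
    max_w_scale = weight_scale.max().float()

    shards = []
    start = 0
    for idx, logical_width in enumerate(logical_widths):
        end = start + logical_width
        # Dequantize with the shard's own scale, requantize with the shared max.
        dequant = weight[start:end].to(torch.float32) * weight_scale[idx]
        shards.append((dequant / max_w_scale).to(qdtype))
        start = end

    return torch.cat(shards), max_w_scale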
@@ -109,7 +112,7 @@ def fp8_quantize(
     weight: torch.Tensor,
     scale: Optional[torch.Tensor] = None,
     scale_upper_bound: Optional[torch.Tensor] = None,
-    qdtype: torch.dtype = torch.float8_e4m3fn,
+    qdtype: torch.dtype = quant_dtype,
     scalar: bool = False,
 ):
     """
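With the default now set to quant_dtype, callers that do not pass qdtype automatically quantize to the platform-appropriate FP8 format. Below is a minimal, self-contained sketch of per-tensor FP8 quantization in the same spirit; the real fp8_quantize also handles preset scales, scale_upper_bound, scalar mode, and marlin/CUDA fast paths not reproduced here.

from typing import Optional, Tuple

import torch


def fp8_quantize_sketch(
    weight: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
    qdtype: torch.dtype = torch.float8_e4m3fn,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Illustrative stand-in for fp8_quantize, not the TGI implementation.
    finfo = torch.finfo(qdtype)
    if scale is None:
        # Per-tensor scale: map the largest magnitude onto the FP8 maximum.
        scale = finfo.max / weight.abs().max().clamp(min=1e-12)
    qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max).to(qdtype)
    # Return the dequantization scale alongside the quantized weight.
    return qweight, scale.float().reciprocal()


if __name__ == "__main__":
    w = torch.randn(4, 8)
    qw, dq_scale = fp8_quantize_sketch(w)
    print(qw.dtype, dq_scale.item())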