Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-19 22:02:06 +00:00)
Enable FP8 Per-Tensor Scales and Integrate Marlin/MoE Kernels Repo for ROCm (#2825)
* (feat) convert tscales to tensorwise
* (fix) fp8 scaling for cuda
* (kernel) add marlin-kernels
* add moe-kernels
* fix moe kernel commit
* fix scaling
* nm changes
This commit is contained in: parent 880ab9c2f3, commit e07acc7f68
@@ -268,6 +268,15 @@ COPY server/exllamav2_kernels/ .
 RUN python setup.py build
 
+FROM kernel-builder AS marlin-kernels
+WORKDIR /usr/src
+ENV MARLIN_KERNELS_BRANCH=v0.3.6
+ENV VLLM_TARGET_DEVICE=rocm
+RUN git clone https://github.com/danieldk/marlin-kernels.git && \
+    cd marlin-kernels && \
+    git checkout ${MARLIN_KERNELS_BRANCH} && \
+    python setup.py install
+
 FROM kernel-builder AS moe-kernels
 WORKDIR /usr/src
 ENV MOE_KERNELS_BRANCH=a67b35841774b2056a73806c36661134b5054edd
@@ -299,6 +308,9 @@ COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311
 # Copy build artifacts from exllamav2 kernels builder
 COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
 
+# Copy build artifacts from marlin kernels
+COPY --from=marlin-kernels /usr/src/marlin-kernels/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
+
 # Copy build artifacts from moe kernels
 COPY --from=moe-kernels /usr/src/moe-kernels/build/lib.linux-x86_64-cpython-311 /opt/conda/lib/python3.11/site-packages
 
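The new stages above build marlin-kernels from source and copy both the marlin and moe build artifacts into the final image. A minimal sanity check for the resulting container could look like the sketch below; it assumes the packages install top-level modules named `marlin_kernels` and `moe_kernels` (only `marlin_kernels` is referenced directly later in this diff), so adjust the names if the projects import differently.

import importlib

# Hedged check: module names are assumed from the package names above.
for name in ("marlin_kernels", "moe_kernels"):
    try:
        module = importlib.import_module(name)
        print(f"{name}: importable ({getattr(module, '__version__', 'version unknown')})")
    except ImportError as exc:
        print(f"{name}: not importable ({exc})")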
@@ -163,17 +163,17 @@ def paged_attention(
     else:
         # Run PagedAttention V2.
         assert _PARTITION_SIZE % block_size == 0
-        tmp_output = torch.empty(
+        tmp_output = torch.zeros(
             size=(num_seqs, num_heads, max_num_partitions, head_size),
             dtype=out.dtype,
             device=out.device,
         )
-        exp_sums = torch.empty(
+        exp_sums = torch.zeros(
             size=(num_seqs, num_heads, max_num_partitions),
             dtype=torch.float32,
             device=out.device,
         )
-        max_logits = torch.empty_like(exp_sums)
+        max_logits = torch.zeros_like(exp_sums)
 
         if not use_custom:
             ops.paged_attention_v2(
@@ -3,8 +3,14 @@ from typing import List, Optional, Union
 import torch
 from compressed_tensors.quantization import QuantizationArgs, QuantizationType
 
-from text_generation_server.layers.fp8 import Fp8Weight, _load_scalar_or_matrix_scale
+from text_generation_server.layers.fp8 import (
+    Fp8Weight,
+    _load_scalar_or_matrix_scale,
+    requantize_with_max_scale,
+    normalize_e4m3fn_to_e4m3fnuz,
+)
 from text_generation_server.utils.weights import Weights, WeightsLoader
+from text_generation_server.utils.import_utils import SYSTEM
 
 
 class W8ANFpLoader(WeightsLoader):
@@ -47,11 +53,10 @@ class W8ANFpLoader(WeightsLoader):
 
         weight_scale = None
         if self.load_weight_scale:
-            weight_scale = (
-                weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
-                .reshape(-1)
-                .expand(w.shape[0])
-            )
+            weight_scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+            if SYSTEM == "cuda":
+                weight_scale = weight_scale.reshape(-1).expand(w.shape[0])
 
         input_scale = None
         if self.load_input_scale:
@@ -87,7 +92,8 @@ class W8ANFpLoader(WeightsLoader):
                 block_sizes=block_sizes,
                 to_dtype=False,
             )
-            weight_scale = weight_scale.reshape(-1).expand(w.shape[0])
+            if SYSTEM == "cuda":
+                weight_scale = weight_scale.reshape(-1).expand(w.shape[0])
 
         input_scale = None
         if self.load_input_scale:
@@ -141,6 +147,17 @@ class W8ANFpLoader(WeightsLoader):
             else None
         )
 
+        if self.load_weight_scale or SYSTEM == "rocm":
+            w, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
+                w, weight_scale, input_scale
+            )
+
+        if weight_scale.numel() == len(prefixes):
+            logical_widths = [x[0] for x in shapes]
+            w, weight_scale = requantize_with_max_scale(
+                w, weight_scale.to(weights.device), logical_widths, weights.dtype
+            )
+
         return Fp8Weight(
             weight=w,
             weight_scale=weight_scale,
@@ -153,11 +170,10 @@ class W8ANFpLoader(WeightsLoader):
         w = weights.get_sharded(f"{prefix}.weight", dim=1)
 
         weight_scale = None
         if self.load_weight_scale:
-            weight_scale = (
-                weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
-                .reshape(-1)
-                .expand(w.shape[0])
-            )
+            weight_scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+            if SYSTEM == "cuda":
+                weight_scale = weight_scale.reshape(-1).expand(w.shape[0])
 
         input_scale = None
         if self.load_input_scale:
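On ROCm the loader above keeps tensor-wise (scalar) scales, and when a fused column load concatenates several prefixes it folds their scales into a single maximum via `requantize_with_max_scale` (defined in the fp8 layer changes further below). A rough standalone sketch of that folding, using made-up shapes and plain float16 arithmetic as a stand-in for real FP8 quantization:

import torch

# Hypothetical fused q/k/v load: one scalar scale per prefix.
logical_widths = [1024, 256, 256]                  # rows contributed by each prefix
weight = torch.randn(sum(logical_widths), 512, dtype=torch.float16)
weight_scale = torch.tensor([0.010, 0.020, 0.015])

# Dequantize each slice with its own scale, then requantize against the shared max.
max_scale = weight_scale.max()
start = 0
for idx, width in enumerate(logical_widths):
    end = start + width
    dequantized = weight[start:end] * weight_scale[idx]
    weight[start:end] = dequantized / max_scale    # stand-in for fp8_quantize(...)
    start = end
# All rows now share one tensor-wise scale: max_scale.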
@@ -19,6 +19,9 @@ try:
 except ImportError:
     marlin_kernels = None
 
+quant_dtype: torch.dtype = (
+    torch.float8_e4m3fnuz if SYSTEM == "rocm" else torch.float8_e4m3fn
+)
 
 if SYSTEM == "cuda" and marlin_kernels is not None:
     major, minor = torch.cuda.get_device_capability()
@@ -60,25 +63,58 @@ def normalize_e4m3fn_to_e4m3fnuz(
     weight_scale: torch.Tensor,
     input_scale: Optional[torch.Tensor] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
-    assert weight.dtype == torch.float8_e4m3fn
-    # The bits pattern 10000000(-128) represents zero in e4m3fn
-    # but NaN in e4m3fnuz. So here we set it to 0.
-    # https://onnx.ai/onnx/technical/float8.html
-    weight_as_int8 = weight.view(torch.int8)
-    ROCM_FP8_NAN_AS_INT = -128
-    weight_as_int8[weight_as_int8 == ROCM_FP8_NAN_AS_INT] = 0
-    weight = weight_as_int8.view(torch.float8_e4m3fnuz)
+    if weight.dtype == torch.float8_e4m3fn:
+        # The bits pattern 10000000(-128) represents zero in e4m3fn
+        # but NaN in e4m3fnuz. So here we set it to 0.
+        # https://onnx.ai/onnx/technical/float8.html
+        weight_as_int8 = weight.view(torch.int8)
+        ROCM_FP8_NAN_AS_INT = -128
+        weight_as_int8[weight_as_int8 == ROCM_FP8_NAN_AS_INT] = 0
+        weight = weight_as_int8.view(torch.float8_e4m3fnuz)
 
     # For the same bits representation, e4m3fnuz value is half of
     # the e4m3fn value, so we should double the scaling factor to
     # get the same dequantized value.
     # https://onnx.ai/onnx/technical/float8.html
     weight_scale = weight_scale * 2.0
     if input_scale is not None:
         input_scale = input_scale * 2.0
     return weight, weight_scale, input_scale
 
 
+def per_tensor_dequantize(
+    tensor: torch.Tensor,
+    inv_scale: Union[float, torch.Tensor],
+    dtype: torch.dtype = torch.float16,
+) -> torch.Tensor:
+    fake_qweight = tensor.to(dtype)
+    dq_weight = fake_qweight * inv_scale
+    return dq_weight
+
+
+def requantize_with_max_scale(
+    weight: torch.Tensor,
+    weight_scale: torch.Tensor,
+    logical_widths: int,
+    dtype: torch.dtype,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    # Max scale to be used for requanitzation.
+    max_w_scale = weight_scale.max().float()
+
+    start = 0
+    for idx, logical_width in enumerate(logical_widths):
+        end = start + logical_width
+        weight_dq = per_tensor_dequantize(
+            weight[start:end, :], weight_scale[idx], dtype
+        )
+        weight[start:end, :], max_w_scale_normalized = fp8_quantize(
+            weight_dq, max_w_scale
+        )
+        start = end
+
+    return weight, max_w_scale_normalized
+
+
 def fp8_quantize(
     weight: torch.Tensor,
     scale: Optional[torch.Tensor] = None,
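The comments in `normalize_e4m3fn_to_e4m3fnuz` above summarize the key numerics: for the same bit pattern, an e4m3fnuz value is half the e4m3fn value (the exponent bias moves from 7 to 8), so the scale is doubled, and the 0x80 pattern must be zeroed because it encodes negative zero in e4m3fn but NaN in e4m3fnuz. A small standalone check of that equivalence (assumes a PyTorch build that exposes both float8 dtypes):

import torch

w_fn = torch.tensor([0.5, -1.0, 2.0]).to(torch.float8_e4m3fn)
scale = torch.tensor(0.25)

bits = w_fn.view(torch.int8).clone()
bits[bits == -128] = 0                    # 0x80: negative zero in e4m3fn, NaN in e4m3fnuz
w_fnuz = bits.view(torch.float8_e4m3fnuz)

dq_fn = w_fn.to(torch.float32) * scale
dq_fnuz = w_fnuz.to(torch.float32) * (scale * 2.0)   # doubled scale compensates
assert torch.allclose(dq_fn, dq_fnuz)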
@@ -96,7 +132,7 @@ def fp8_quantize(
         shape = weight.shape
         qweight, scale = marlin_kernels.scaled_fp8_quant(
             weight.reshape(-1, shape[-1]),
-            dtype=qdtype,
+            dtype=quant_dtype,
             scale=scale,
             scale_ub=scale_upper_bound,
             # TODO: don't do this when we have to use the Torch kernel.
@@ -116,6 +152,8 @@ def fp8_quantize(
         qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max)
         scale = scale.float().reciprocal()
     else:
+        if SYSTEM == "rocm":
+            scale = scale / 2.0
         # Use reciprocal to avoid more expensive division.
         qweight = (weight * scale.reciprocal()).clamp(min=finfo.min, max=finfo.max)
 
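The `quant_dtype` introduced above switches the quantization target between the two 8-bit float variants. Their representable ranges differ, which is worth keeping in mind when reading the CUDA/ROCm branches; for reference, the maxima reported by `torch.finfo`:

import torch

print(torch.finfo(torch.float8_e4m3fn).max)    # 448.0 (CUDA path)
print(torch.finfo(torch.float8_e4m3fnuz).max)  # 240.0 (ROCm path)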
@@ -141,17 +179,18 @@ class HybridFP8UnquantLoader(WeightsLoader):
 
         if w.dtype == torch.float8_e4m3fn:
             # FP8 branch
-            scale = (
-                weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
-                .reshape(-1)
-                .expand(w.shape[0])
-            )
+            scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+            if SYSTEM == "cuda":
+                scale = scale.reshape(-1).expand(w.shape[0])
 
             input_scale = None
             if weights.has_tensor(f"{prefix}.input_scale"):
-                input_scale = weights.get_tensor(
-                    f"{prefix}.input_scale", to_dtype=False
-                ).reshape(-1)
+                input_scale = (
+                    weights.get_tensor(f"{prefix}.input_scale", to_dtype=False)
+                    .reshape(-1)
+                    .max()
+                )
 
         return Fp8Weight(
             weight=w,
@@ -178,6 +217,7 @@ class HybridFP8UnquantLoader(WeightsLoader):
         if w.dtype == torch.float8_e4m3fn:
             # FP8 branch
             scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+
             if scale.numel() > 1:
                 scale = weights.get_packed_sharded(
                     f"{prefix}.weight_scale",
@@ -185,7 +225,8 @@ class HybridFP8UnquantLoader(WeightsLoader):
                     block_sizes=block_sizes,
                     to_dtype=False,
                 )
-            scale = scale.reshape(-1).expand(w.shape[0])
+            if SYSTEM == "cuda":
+                scale = scale.reshape(-1).expand(w.shape[0])
 
             input_scale = None
             if weights.has_tensor(f"{prefix}.input_scale"):
@@ -243,6 +284,17 @@ class HybridFP8UnquantLoader(WeightsLoader):
             else None
         )
 
+        if SYSTEM == "rocm":
+            w, scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
+                w, scale, input_scale
+            )
+
+        if scale.numel() == len(prefixes):
+            logical_widths = [x[0] for x in shapes]
+            w, scale = requantize_with_max_scale(
+                w, scale.to(weights.device), logical_widths, weights.dtype
+            )
+
         return Fp8Weight(
             weight=w,
             weight_scale=scale,
@@ -259,16 +311,18 @@ class HybridFP8UnquantLoader(WeightsLoader):
         w = weights.get_sharded(f"{prefix}.weight", dim=1)
         # FP8 branch
         if w.dtype == torch.float8_e4m3fn:
-            scale = (
-                weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
-                .reshape(-1)
-                .expand(w.shape[0])
-            )
+            scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+            if SYSTEM == "cuda":
+                scale = scale.reshape(-1).expand(w.shape[0])
+
             input_scale = None
             if weights.has_tensor(f"{prefix}.input_scale"):
-                input_scale = weights.get_tensor(
-                    f"{prefix}.input_scale", to_dtype=False
-                ).reshape(-1)
+                input_scale = (
+                    weights.get_tensor(f"{prefix}.input_scale", to_dtype=False)
+                    .reshape(-1)
+                    .max()
+                )
 
         return Fp8Weight(
             weight=w,
@@ -326,7 +380,7 @@ class Fp8Linear(torch.nn.Module):
         if CUTLASS_FP8_AVAILABLE:
             log_once(logger.info, "Using cutlass w8a8 kernels")
         if SYSTEM == "rocm" and qweight.dtype == torch.float8_e4m3fn:
-            qweight, scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+            qweight, scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
                 weight=qweight, weight_scale=scale
             )
 
@@ -443,6 +497,9 @@ class Fp8Linear(torch.nn.Module):
 
 def _load_scalar_or_matrix_scale(weights: Weights, prefix: str, shape: torch.Size):
     scale = weights.get_tensor(prefix, to_dtype=False)
+
     if scale.numel() > 1:
         scale = weights.get_sharded(prefix, dim=0, to_dtype=False)
+    elif SYSTEM == "rocm":
+        return scale.reshape(-1)
     return scale.reshape(-1).expand(shape[0])
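With the `_load_scalar_or_matrix_scale` change above, a per-tensor scale stays a single element on ROCm instead of being broadcast to one entry per output row as on CUDA, while genuinely per-row scales still go through `get_sharded`. A toy illustration of the resulting shapes (the row count below is hypothetical):

import torch

scale = torch.tensor([0.02])   # per-tensor scale as stored in the checkpoint
rows = 4096                    # hypothetical number of output rows in this shard

cuda_scale = scale.reshape(-1).expand(rows)   # shape (4096,): one entry per row
rocm_scale = scale.reshape(-1)                # shape (1,): stays tensor-wise
print(cuda_scale.shape, rocm_scale.shape)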