mirror of https://github.com/huggingface/text-generation-inference.git
(fix) fp8 scaling for cuda

commit e22cb47fe3 (parent: e2454dba40)
@@ -116,17 +116,17 @@ def paged_attention(
     else:
         # Run PagedAttention V2.
         assert _PARTITION_SIZE % block_size == 0
-        tmp_output = torch.empty(
+        tmp_output = torch.zeros(
             size=(num_seqs, num_heads, max_num_partitions, head_size),
             dtype=out.dtype,
             device=out.device,
         )
-        exp_sums = torch.empty(
+        exp_sums = torch.zeros(
             size=(num_seqs, num_heads, max_num_partitions),
             dtype=torch.float32,
             device=out.device,
         )
-        max_logits = torch.empty_like(exp_sums)
+        max_logits = torch.zeros_like(exp_sums)

         if not use_custom:
             ops.paged_attention_v2(
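
Note: the three buffers above switch from torch.empty to torch.zeros, so partitions the V2 kernel leaves unwritten contain zeros rather than stale memory. The commit message does not state the motivation, so the following is only a plausible illustration of why uninitialized buffers can poison a later reduction; the names mirror the shapes above.

import torch

num_seqs, num_heads, max_num_partitions, head_size = 2, 4, 8, 64

tmp_empty = torch.empty(num_seqs, num_heads, max_num_partitions, head_size)
tmp_zeros = torch.zeros(num_seqs, num_heads, max_num_partitions, head_size)

# Pretend the kernel only wrote the first 3 partitions of a short sequence:
tmp_empty[:, :, :3] = 1.0
tmp_zeros[:, :, :3] = 1.0

# A reduction over the partition dimension mixes whatever garbage the
# allocator left behind into tmp_empty, while tmp_zeros stays well defined.
print(torch.allclose(tmp_empty.sum(dim=2), tmp_zeros.sum(dim=2)))  # almost surely False
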
@@ -3,8 +3,13 @@ from typing import List, Optional, Union
 import torch
 from compressed_tensors.quantization import QuantizationArgs, QuantizationType

-from text_generation_server.layers.fp8 import Fp8Weight, _load_scalar_or_matrix_scale
+from text_generation_server.layers.fp8 import (
+    Fp8Weight,
+    _load_scalar_or_matrix_scale,
+    requantize_with_max_scale,
+)
 from text_generation_server.utils.weights import Weights, WeightsLoader
+from text_generation_server.utils.import_utils import SYSTEM


 class W8ANFpLoader(WeightsLoader):
@@ -47,11 +52,10 @@ class W8ANFpLoader(WeightsLoader):

         weight_scale = None
         if self.load_weight_scale:
-            weight_scale = (
-                weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
-                .reshape(-1)
-                .expand(w.shape[0])
-            )
+            weight_scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+
+            if SYSTEM == "cuda":
+                weight_scale = weight_scale.reshape(-1).expand(w.shape[0])

         input_scale = None
         if self.load_input_scale:
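
Note: this load-then-guard pattern recurs in nearly every hunk below. A minimal sketch of the shape handling, assuming (as the diff implies) that the CUDA fp8 kernels consume one scale per output channel while other systems use the scale as stored; expand_weight_scale is a hypothetical name, not a function in the codebase.

import torch

def expand_weight_scale(weight_scale: torch.Tensor, out_features: int, system: str) -> torch.Tensor:
    # Broadcast a per-tensor scalar scale to one entry per output channel on
    # CUDA; leave it untouched for other systems (e.g. ROCm).
    if system == "cuda":
        return weight_scale.reshape(-1).expand(out_features)
    return weight_scale

scalar_scale = torch.tensor([0.02])
print(expand_weight_scale(scalar_scale, 4096, "cuda").shape)  # torch.Size([4096])
print(expand_weight_scale(scalar_scale, 4096, "rocm").shape)  # torch.Size([1])
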
@@ -87,7 +91,8 @@ class W8ANFpLoader(WeightsLoader):
                     block_sizes=block_sizes,
                     to_dtype=False,
                 )
-            weight_scale = weight_scale.reshape(-1).expand(w.shape[0])
+            if SYSTEM == "cuda":
+                weight_scale = weight_scale.reshape(-1).expand(w.shape[0])

         input_scale = None
         if self.load_input_scale:
@@ -127,6 +132,12 @@ class W8ANFpLoader(WeightsLoader):
             ]
             weight_scale = torch.cat(weight_scale, dim=0).reshape(-1)

+            if weight_scale.numel() == len(prefixes):
+                logical_widths = [x[0] for x in shapes]
+                w, weight_scale = requantize_with_max_scale(
+                    w, weight_scale.to(weights.device), logical_widths
+                )
+
         input_scale = None
         if self.load_input_scale:
             input_scale = [
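
Note: requantize_with_max_scale comes from text_generation_server.layers.fp8 (see the new import above) and only runs when every fused prefix contributed exactly one per-tensor scale. Its internals are not shown in this diff; the sketch below is an assumption about what such a helper does — dequantize each logical shard with its own scale, then requantize the fused tensor against the shared maximum — and the real implementation may differ (e.g. clamping to the fp8 range).

import torch
from typing import List, Tuple

def requantize_with_max_scale_sketch(
    weight: torch.Tensor,        # fused fp8 weight, shards stacked along dim 0
    weight_scale: torch.Tensor,  # one per-tensor scale per shard
    logical_widths: List[int],   # number of rows belonging to each shard
) -> Tuple[torch.Tensor, torch.Tensor]:
    max_scale = weight_scale.max()
    requantized = torch.empty_like(weight)
    start = 0
    for width, scale in zip(logical_widths, weight_scale):
        end = start + width
        # Recover the shard in float32, then re-encode it with the shared scale.
        dequantized = weight[start:end].to(torch.float32) * scale
        requantized[start:end] = (dequantized / max_scale).to(weight.dtype)
        start = end
    return requantized, max_scale
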
@@ -153,11 +164,10 @@ class W8ANFpLoader(WeightsLoader):
         w = weights.get_sharded(f"{prefix}.weight", dim=1)
         weight_scale = None
         if self.load_weight_scale:
-            weight_scale = (
-                weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
-                .reshape(-1)
-                .expand(w.shape[0])
-            )
+            weight_scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+
+            if SYSTEM == "cuda":
+                weight_scale = weight_scale.reshape(-1).expand(w.shape[0])

         input_scale = None
         if self.load_input_scale:
@@ -167,11 +167,10 @@ class HybridFP8UnquantLoader(WeightsLoader):

         if w.dtype == torch.float8_e4m3fn:
             # FP8 branch
-            scale = (
-                weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
-                .reshape(-1)
-                .expand(w.shape[0])
-            ).max()
+            scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+
+            if SYSTEM == "cuda":
+                scale = scale.reshape(-1).expand(w.shape[0])

             input_scale = None
             if weights.has_tensor(f"{prefix}.input_scale"):
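
Note: besides adding the CUDA guard, this hunk drops the trailing .max(), so a genuinely per-channel scale is no longer collapsed to one scalar. A toy comparison with made-up values:

import torch

per_channel = torch.tensor([0.01, 0.04, 0.02])
print(per_channel.reshape(-1).expand(3).max())  # old: tensor(0.0400), one scalar for all rows
print(per_channel.reshape(-1).expand(3))        # new (CUDA): tensor([0.0100, 0.0400, 0.0200])
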
@@ -206,6 +205,7 @@ class HybridFP8UnquantLoader(WeightsLoader):
         if w.dtype == torch.float8_e4m3fn:
             # FP8 branch
             scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+
             if scale.numel() > 1:
                 scale = weights.get_packed_sharded(
                     f"{prefix}.weight_scale",
@@ -213,7 +213,8 @@ class HybridFP8UnquantLoader(WeightsLoader):
                     block_sizes=block_sizes,
                     to_dtype=False,
                 )
-            scale = scale.reshape(-1).expand(w.shape[0]).max()
+            if SYSTEM == "cuda":
+                scale = scale.reshape(-1).expand(w.shape[0])

             input_scale = None
             if weights.has_tensor(f"{prefix}.input_scale"):
@@ -255,15 +256,15 @@ class HybridFP8UnquantLoader(WeightsLoader):
         if w.dtype == torch.float8_e4m3fn:
             scale = [
                 _load_scalar_or_matrix_scale(weights, f"{p}.weight_scale", shape)
-                .max()
-                .unsqueeze(0)
                 for p, shape in zip(prefixes, shapes)
             ]
-            scale = torch.cat(scale).to(weights.device)
+            scale = torch.cat(scale, dim=0).reshape(-1)

-            logical_widths = [x[0] for x in shapes]
-            w, scale = requantize_with_max_scale(w, scale, logical_widths)
+            if scale.numel() == len(prefixes):
+                logical_widths = [x[0] for x in shapes]
+                w, scale = requantize_with_max_scale(
+                    w, scale.to(weights.device), logical_widths
+                )

             input_scale = [
                 _load_scalar_or_matrix_scale(weights, f"{p}.input_scale", shape)
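
Note: the fused-column path previously took .max().unsqueeze(0) of every shard's scale and requantized unconditionally; now the raw scales are concatenated and requantization applies only when they are per-tensor scalars, i.e. exactly one element per prefix. A toy illustration of that guard, with hypothetical prefix names and values:

import torch

prefixes = ["q_proj", "k_proj", "v_proj"]

per_tensor = torch.cat([torch.tensor([0.01]), torch.tensor([0.03]), torch.tensor([0.02])])
per_channel = torch.cat([torch.rand(128) for _ in prefixes])

print(per_tensor.numel() == len(prefixes))   # True  -> requantize with the shared max scale
print(per_channel.numel() == len(prefixes))  # False -> keep the per-channel scales as-is
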
@@ -293,11 +294,11 @@ class HybridFP8UnquantLoader(WeightsLoader):
         w = weights.get_sharded(f"{prefix}.weight", dim=1)
         # FP8 branch
         if w.dtype == torch.float8_e4m3fn:
-            scale = (
-                weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
-                .reshape(-1)
-                .expand(w.shape[0])
-            ).max()
+            scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+
+            if SYSTEM == "cuda":
+                scale = scale.reshape(-1).expand(w.shape[0])

             input_scale = None
             if weights.has_tensor(f"{prefix}.input_scale"):
                 input_scale = (
@@ -479,6 +480,9 @@ class Fp8Linear(torch.nn.Module):

 def _load_scalar_or_matrix_scale(weights: Weights, prefix: str, shape: torch.Size):
     scale = weights.get_tensor(prefix, to_dtype=False)
+
     if scale.numel() > 1:
         scale = weights.get_sharded(prefix, dim=0, to_dtype=False)
+    elif SYSTEM == "rocm":
+        return scale.reshape(-1)
     return scale.reshape(-1).expand(shape[0])
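
Note: the new elif returns a scalar scale as a one-element tensor on ROCm instead of broadcasting it; per-channel scales still fall through to the final expand. Resulting shapes for a hypothetical layer with shape[0] == 4096:

import torch

scalar = torch.tensor(0.02)
print(scalar.reshape(-1).shape)               # ROCm:      torch.Size([1])
print(scalar.reshape(-1).expand(4096).shape)  # elsewhere: torch.Size([4096])
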