add default dtype

This commit is contained in:
OlivierDehaene 2024-07-19 19:38:01 +02:00
parent 10cd8ab4a6
commit 081d16cab5
No known key found for this signature in database
GPG Key ID: BB104D67809DA93C
4 changed files with 70 additions and 39 deletions

View File

@ -57,14 +57,15 @@ def fp8_quantize(weight, scale_upper_bound=None, qdtype=torch.float8_e4m3fn):
@dataclass
class Fp8Weight:
weight: torch.Tensor
dtype: torch.dtype
weight_scale: Optional[torch.Tensor] = None
input_scale: Optional[torch.Tensor] = None
def get_linear(self, bias: torch.Tensor):
if self.weight_scale is None:
return get_fp8_linear().from_unquant(self.weight, bias)
return get_fp8_linear().from_unquant(self.weight, bias, self.dtype)
return get_fp8_linear().from_fp8(
self.weight, self.weight_scale, self.input_scale, bias, bias.dtype
self.weight, self.weight_scale, self.input_scale, bias, self.dtype
)
@ -110,7 +111,7 @@ class Fp8Linear(torch.nn.Module):
y = torch.ops.fbgemm.f8f8bf16_rowwise(
qinput,
self.weight,
self.qweight,
scale,
self.scale,
use_fast_accum=True,

View File

@ -530,13 +530,13 @@ class GPTQMarlinFP8Linear(nn.Module):
)
@classmethod
def from_unquant(cls, weight, bias):
def from_unquant(cls, weight, bias, _dtype):
qweight, scale = fp8_quantize(weight)
return cls(qweight=qweight, scale=scale, bias=bias)
@classmethod
def from_fp8(cls, weight, bias):
return cls(qweight=weight.weight, scale=weight.weight_scale, bias=bias)
def from_fp8(cls, weight, scale, _input_scale, bias, _dtype):
return cls(qweight=weight, scale=scale, bias=bias)
def forward(self, A: torch.Tensor) -> torch.Tensor:
assert marlin_kernels is not None

View File

@ -311,6 +311,9 @@ def get_model(
if quantize in ["awq", "exl2", "gptq", "marlin"]:
# These quantizers only work with float16 params.
dtype = torch.float16
elif quantize == "fp8":
# gemm kernels are fp8xfp8->bf16
dtype = torch.bfloat16
else:
# Keep it as default for now and let
# every model resolve their own default dtype.

View File

@ -86,6 +86,7 @@ class Weight(ABC):
@dataclass
class UnquantizedWeight:
weight: torch.Tensor
dtype: torch.dtype
def get_linear(self, bias: torch.Tensor):
from text_generation_server.layers.linear import FastLinear, FastLinearROCm
@ -137,14 +138,19 @@ class DefaultWeightsLoader(WeightsLoader):
# FP8 branch
scale = weights.get_packed_sharded(
f"{prefix}.weight_scale", dim=0, block_sizes=block_sizes
f"{prefix}.weight_scale", dim=0, block_sizes=block_sizes, cast=False
)
input_scale = weights.get_tensor(f"{prefix}.input_scale", cast=False)
return Fp8Weight(
weight=w,
weight_scale=scale,
input_scale=input_scale,
dtype=weights.dtype,
)
input_scale = weights.get_tensor(f"{prefix}.input_scale")
return Fp8Weight(weight=w, weight_scale=scale, input_scale=input_scale)
if self.weight_class is None:
return UnquantizedWeight(w)
return self.weight_class(w)
return UnquantizedWeight(w, dtype=weights.dtype)
return self.weight_class(w, dtype=weights.dtype)
def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
w = [weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
@ -160,14 +166,22 @@ class DefaultWeightsLoader(WeightsLoader):
f"Deserialized quantised fp8 weights but weight class is {self.weight_class}"
)
scale = [weights.get_sharded(f"{p}.weight_scale", dim=0) for p in prefixes]
scale = [
weights.get_sharded(f"{p}.weight_scale", dim=0, cast=False)
for p in prefixes
]
scale = torch.cat(scale, dim=0)
input_scale = weights.get_tensor(f"{prefixes[0]}.input_scale")
return Fp8Weight(weight=w, weight_scale=scale, input_scale=input_scale)
input_scale = weights.get_tensor(f"{prefixes[0]}.input_scale", cast=False)
return Fp8Weight(
weight=w,
weight_scale=scale,
input_scale=input_scale,
dtype=weights.dtype,
)
if self.weight_class is None:
return UnquantizedWeight(w)
return self.weight_class(w)
return UnquantizedWeight(w, dtype=weights.dtype)
return self.weight_class(w, dtype=weights.dtype)
def get_weights_row(self, weights: "Weights", prefix: str):
w = weights.get_sharded(f"{prefix}.weight", dim=1)
@ -181,13 +195,18 @@ class DefaultWeightsLoader(WeightsLoader):
f"Deserialized quantised fp8 weights but weight class is {self.weight_class}"
)
scale = weights.get_sharded(f"{prefix}.weight_scale", dim=0)
input_scale = weights.get_tensor(f"{prefix}.input_scale")
return Fp8Weight(weight=w, weight_scale=scale, input_scale=input_scale)
scale = weights.get_sharded(f"{prefix}.weight_scale", dim=0, cast=False)
input_scale = weights.get_tensor(f"{prefix}.input_scale", cast=False)
return Fp8Weight(
weight=w,
weight_scale=scale,
input_scale=input_scale,
dtype=weights.dtype,
)
if self.weight_class is None:
return UnquantizedWeight(w)
return self.weight_class(w)
return UnquantizedWeight(w, dtype=weights.dtype)
return self.weight_class(w, dtype=weights.dtype)
class Weights:
@ -261,25 +280,29 @@ class Weights:
def get_shape(self, tensor_name: str):
return self._get_slice(tensor_name).get_shape()
def get_tensor(self, tensor_name: str, to_device=True):
def get_tensor(self, tensor_name: str, to_device=True, cast=True):
filename, tensor_name = self.get_filename(tensor_name)
f = self._get_handle(filename)
tensor = f.get_tensor(tensor_name)
# Special case for gptq which shouldn't convert
# u4 which are disguised as int32. Exl2 uses int16
# as well. FP8 uses torch.float8_e4m3fn
if tensor.dtype not in [
if (
tensor.dtype
not in [
torch.float8_e4m3fn,
torch.int16,
torch.int32,
torch.int64,
]:
]
and cast
):
tensor = tensor.to(dtype=self.dtype)
if to_device:
tensor = tensor.to(device=self.device)
return tensor
def get_partial_sharded(self, tensor_name: str, dim: int):
def get_partial_sharded(self, tensor_name: str, dim: int, cast=True):
filename, tensor_name = self.get_filename(tensor_name)
f = self._get_handle(filename)
slice_ = f.get_slice(tensor_name)
@ -300,12 +323,12 @@ class Weights:
# Special case for gptq which shouldn't convert
# u4 which are disguised as int32. exl2 uses int16.
# FP8 uses torch.float8_e4m3fn.
if tensor.dtype not in (torch.float8_e4m3fn, torch.int16, torch.int32):
if tensor.dtype not in (torch.float8_e4m3fn, torch.int16, torch.int32) and cast:
tensor = tensor.to(dtype=self.dtype)
tensor = tensor.to(device=self.device)
return tensor
def get_sharded(self, tensor_name: str, dim: int):
def get_sharded(self, tensor_name: str, dim: int, cast=True):
filename, tensor_name = self.get_filename(tensor_name)
f = self._get_handle(filename)
slice_ = f.get_slice(tensor_name)
@ -314,10 +337,10 @@ class Weights:
assert (
size % world_size == 0
), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
return self.get_partial_sharded(tensor_name, dim)
return self.get_partial_sharded(tensor_name, dim, cast=cast)
def get_packed_sharded(
self, tensor_name: str, dim: int, block_sizes: Union[int, List[int]]
self, tensor_name: str, dim: int, block_sizes: Union[int, List[int]], cast=True
) -> torch.Tensor:
"""
Get a shard from a tensor that packs multiple tensors.
@ -363,12 +386,16 @@ class Weights:
tensor = tensor.to(device=self.device)
# Avoid casting quantizer dtypes.
if tensor.dtype not in [
if (
tensor.dtype
not in [
torch.float8_e4m3fn,
torch.int16,
torch.int32,
torch.int64,
]:
]
and cast
):
tensor = tensor.to(dtype=self.dtype)
return tensor