mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 12:54:52 +00:00

add default dtype

commit 081d16cab5 (parent 10cd8ab4a6)
@@ -57,14 +57,15 @@ def fp8_quantize(weight, scale_upper_bound=None, qdtype=torch.float8_e4m3fn):
 @dataclass
 class Fp8Weight:
     weight: torch.Tensor
+    dtype: torch.dtype
     weight_scale: Optional[torch.Tensor] = None
     input_scale: Optional[torch.Tensor] = None

     def get_linear(self, bias: torch.Tensor):
         if self.weight_scale is None:
-            return get_fp8_linear().from_unquant(self.weight, bias)
+            return get_fp8_linear().from_unquant(self.weight, bias, self.dtype)
         return get_fp8_linear().from_fp8(
-            self.weight, self.weight_scale, self.input_scale, bias, bias.dtype
+            self.weight, self.weight_scale, self.input_scale, bias, self.dtype
         )

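For orientation, a minimal sketch of how the updated dataclass is constructed. The import path, tensor shapes and values are assumptions for illustration, and actually calling get_linear additionally needs one of TGI's FP8 kernel backends:

import torch

# Assumed module path for the file patched above.
from text_generation_server.layers.fp8 import Fp8Weight, fp8_quantize

w = torch.randn(4096, 4096, dtype=torch.bfloat16)

# Unquantized checkpoint: get_linear() quantizes on the fly and now passes
# the target dtype along explicitly instead of inferring it later.
unquant = Fp8Weight(weight=w, dtype=torch.bfloat16)

# Pre-quantized checkpoint: previously the dtype was recovered from
# bias.dtype, which cannot work for bias-free layers; it now travels with
# the weight itself.
qweight, scale = fp8_quantize(w)
prequant = Fp8Weight(
    weight=qweight,
    weight_scale=scale,
    input_scale=None,
    dtype=torch.bfloat16,
)
# linear = prequant.get_linear(bias=None)  # needs an FP8-capable backend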
@@ -110,7 +111,7 @@ class Fp8Linear(torch.nn.Module):

             y = torch.ops.fbgemm.f8f8bf16_rowwise(
                 qinput,
-                self.weight,
+                self.qweight,
                 scale,
                 self.scale,
                 use_fast_accum=True,
@@ -530,13 +530,13 @@ class GPTQMarlinFP8Linear(nn.Module):
         )

     @classmethod
-    def from_unquant(cls, weight, bias):
+    def from_unquant(cls, weight, bias, _dtype):
         qweight, scale = fp8_quantize(weight)
         return cls(qweight=qweight, scale=scale, bias=bias)

     @classmethod
-    def from_fp8(cls, weight, bias):
-        return cls(qweight=weight.weight, scale=weight.weight_scale, bias=bias)
+    def from_fp8(cls, weight, scale, _input_scale, bias, _dtype):
+        return cls(qweight=weight, scale=scale, bias=bias)

     def forward(self, A: torch.Tensor) -> torch.Tensor:
         assert marlin_kernels is not None
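Read together with the Fp8Weight change above, both FP8 linear backends now accept the same constructor arguments; the Marlin backend simply ignores the ones it does not need (hence the leading underscores in the diff). A hypothetical Protocol, not code from the repo, stating that shared interface:

from typing import Optional, Protocol

import torch


class Fp8LinearFactory(Protocol):
    """Illustrative sketch of the interface that get_fp8_linear()
    implementations share after this change."""

    @classmethod
    def from_unquant(
        cls, weight: torch.Tensor, bias: Optional[torch.Tensor], dtype: torch.dtype
    ) -> torch.nn.Module: ...

    @classmethod
    def from_fp8(
        cls,
        weight: torch.Tensor,
        scale: torch.Tensor,
        input_scale: Optional[torch.Tensor],
        bias: Optional[torch.Tensor],
        dtype: torch.dtype,
    ) -> torch.nn.Module: ...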
@@ -311,6 +311,9 @@ def get_model(
     if quantize in ["awq", "exl2", "gptq", "marlin"]:
         # These quantizers only work with float16 params.
         dtype = torch.float16
+    elif quantize == "fp8":
+        # gemm kernels are fp8xfp8->bf16
+        dtype = torch.bfloat16
     else:
         # Keep it as default for now and let
         # every model resolve their own default dtype.
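Condensed as a standalone helper (an illustrative restatement, not a function that exists in the repo), the dtype-selection branch added above reads:

from typing import Optional

import torch


def default_dtype_for_quantization(quantize: Optional[str]) -> Optional[torch.dtype]:
    if quantize in ("awq", "exl2", "gptq", "marlin"):
        # These quantizers only work with float16 params.
        return torch.float16
    if quantize == "fp8":
        # The fp8 x fp8 gemm kernels produce bfloat16 outputs, so bfloat16
        # params avoid an extra cast on the way out.
        return torch.bfloat16
    # Otherwise leave it unset and let every model resolve its own default.
    return None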
@@ -86,6 +86,7 @@ class Weight(ABC):
 @dataclass
 class UnquantizedWeight:
     weight: torch.Tensor
+    dtype: torch.dtype

     def get_linear(self, bias: torch.Tensor):
         from text_generation_server.layers.linear import FastLinear, FastLinearROCm
@@ -137,14 +138,19 @@ class DefaultWeightsLoader(WeightsLoader):

             # FP8 branch
             scale = weights.get_packed_sharded(
-                f"{prefix}.weight_scale", dim=0, block_sizes=block_sizes
+                f"{prefix}.weight_scale", dim=0, block_sizes=block_sizes, cast=False
             )
-            input_scale = weights.get_tensor(f"{prefix}.input_scale")
-            return Fp8Weight(weight=w, weight_scale=scale, input_scale=input_scale)
+            input_scale = weights.get_tensor(f"{prefix}.input_scale", cast=False)
+            return Fp8Weight(
+                weight=w,
+                weight_scale=scale,
+                input_scale=input_scale,
+                dtype=weights.dtype,
+            )

         if self.weight_class is None:
-            return UnquantizedWeight(w)
-        return self.weight_class(w)
+            return UnquantizedWeight(w, dtype=weights.dtype)
+        return self.weight_class(w, dtype=weights.dtype)

     def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
         w = [weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
@@ -160,14 +166,22 @@ class DefaultWeightsLoader(WeightsLoader):
                     f"Deserialized quantised fp8 weights but weight class is {self.weight_class}"
                 )

-            scale = [weights.get_sharded(f"{p}.weight_scale", dim=0) for p in prefixes]
+            scale = [
+                weights.get_sharded(f"{p}.weight_scale", dim=0, cast=False)
+                for p in prefixes
+            ]
             scale = torch.cat(scale, dim=0)
-            input_scale = weights.get_tensor(f"{prefixes[0]}.input_scale")
-            return Fp8Weight(weight=w, weight_scale=scale, input_scale=input_scale)
+            input_scale = weights.get_tensor(f"{prefixes[0]}.input_scale", cast=False)
+            return Fp8Weight(
+                weight=w,
+                weight_scale=scale,
+                input_scale=input_scale,
+                dtype=weights.dtype,
+            )

         if self.weight_class is None:
-            return UnquantizedWeight(w)
-        return self.weight_class(w)
+            return UnquantizedWeight(w, dtype=weights.dtype)
+        return self.weight_class(w, dtype=weights.dtype)

     def get_weights_row(self, weights: "Weights", prefix: str):
         w = weights.get_sharded(f"{prefix}.weight", dim=1)
@@ -181,13 +195,18 @@ class DefaultWeightsLoader(WeightsLoader):
                     f"Deserialized quantised fp8 weights but weight class is {self.weight_class}"
                 )

-            scale = weights.get_sharded(f"{prefix}.weight_scale", dim=0)
-            input_scale = weights.get_tensor(f"{prefix}.input_scale")
-            return Fp8Weight(weight=w, weight_scale=scale, input_scale=input_scale)
+            scale = weights.get_sharded(f"{prefix}.weight_scale", dim=0, cast=False)
+            input_scale = weights.get_tensor(f"{prefix}.input_scale", cast=False)
+            return Fp8Weight(
+                weight=w,
+                weight_scale=scale,
+                input_scale=input_scale,
+                dtype=weights.dtype,
+            )

         if self.weight_class is None:
-            return UnquantizedWeight(w)
-        return self.weight_class(w)
+            return UnquantizedWeight(w, dtype=weights.dtype)
+        return self.weight_class(w, dtype=weights.dtype)


 class Weights:
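The recurring cast=False on the weight_scale and input_scale reads is there so those tensors keep their serialized dtype (typically float32) instead of being coerced to the model dtype by the Weights accessors, whose new cast flag is introduced in the next hunks. A small self-contained illustration of what that coercion would cost (the scale value is made up):

import torch

scale_fp32 = torch.tensor([0.004217], dtype=torch.float32)

# What the old code path effectively did for a bfloat16 model:
coerced = scale_fp32.to(torch.bfloat16)

# bfloat16 keeps only roughly three significant decimal digits, so the
# dequantization factor picks up rounding error that the fp8 kernels would
# then bake into every output element.
print(coerced.float().item(), (scale_fp32 - coerced.float()).abs().item())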
@@ -261,25 +280,29 @@ class Weights:
     def get_shape(self, tensor_name: str):
         return self._get_slice(tensor_name).get_shape()

-    def get_tensor(self, tensor_name: str, to_device=True):
+    def get_tensor(self, tensor_name: str, to_device=True, cast=True):
         filename, tensor_name = self.get_filename(tensor_name)
         f = self._get_handle(filename)
         tensor = f.get_tensor(tensor_name)
         # Special case for gptq which shouldn't convert
         # u4 which are disguised as int32. Exl2 uses int16
         # as well. FP8 uses torch.float8_e4m3fn
-        if tensor.dtype not in [
-            torch.float8_e4m3fn,
-            torch.int16,
-            torch.int32,
-            torch.int64,
-        ]:
+        if (
+            tensor.dtype
+            not in [
+                torch.float8_e4m3fn,
+                torch.int16,
+                torch.int32,
+                torch.int64,
+            ]
+            and cast
+        ):
             tensor = tensor.to(dtype=self.dtype)
         if to_device:
             tensor = tensor.to(device=self.device)
         return tensor

-    def get_partial_sharded(self, tensor_name: str, dim: int):
+    def get_partial_sharded(self, tensor_name: str, dim: int, cast=True):
         filename, tensor_name = self.get_filename(tensor_name)
         f = self._get_handle(filename)
         slice_ = f.get_slice(tensor_name)
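The new guard in get_tensor (the same pattern appears again in get_packed_sharded below) boils down to the following standalone helper. This is an illustrative restatement, not a function defined in the repo:

import torch

# Packed-integer and fp8 payloads must never be converted; the new `cast`
# flag additionally lets callers opt out for plain-float tensors such as
# fp8 scales.
_NEVER_CAST = (torch.float8_e4m3fn, torch.int16, torch.int32, torch.int64)


def maybe_cast(tensor: torch.Tensor, target: torch.dtype, cast: bool = True) -> torch.Tensor:
    if tensor.dtype not in _NEVER_CAST and cast:
        tensor = tensor.to(dtype=target)
    return tensor

With cast=False, a float32 weight_scale passes through untouched even when the target model dtype is bfloat16.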
@@ -300,12 +323,12 @@ class Weights:
         # Special case for gptq which shouldn't convert
         # u4 which are disguised as int32. exl2 uses int16.
         # FP8 uses torch.float8_e4m3fn.
-        if tensor.dtype not in (torch.float8_e4m3fn, torch.int16, torch.int32):
+        if tensor.dtype not in (torch.float8_e4m3fn, torch.int16, torch.int32) and cast:
             tensor = tensor.to(dtype=self.dtype)
         tensor = tensor.to(device=self.device)
         return tensor

-    def get_sharded(self, tensor_name: str, dim: int):
+    def get_sharded(self, tensor_name: str, dim: int, cast=True):
         filename, tensor_name = self.get_filename(tensor_name)
         f = self._get_handle(filename)
         slice_ = f.get_slice(tensor_name)
@@ -314,10 +337,10 @@ class Weights:
         assert (
             size % world_size == 0
         ), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
-        return self.get_partial_sharded(tensor_name, dim)
+        return self.get_partial_sharded(tensor_name, dim, cast=cast)

     def get_packed_sharded(
-        self, tensor_name: str, dim: int, block_sizes: Union[int, List[int]]
+        self, tensor_name: str, dim: int, block_sizes: Union[int, List[int]], cast=True
     ) -> torch.Tensor:
         """
         Get a shard from a tensor that packs multiple tensors.
@@ -363,12 +386,16 @@ class Weights:
         tensor = tensor.to(device=self.device)

         # Avoid casting quantizer dtypes.
-        if tensor.dtype not in [
-            torch.float8_e4m3fn,
-            torch.int16,
-            torch.int32,
-            torch.int64,
-        ]:
+        if (
+            tensor.dtype
+            not in [
+                torch.float8_e4m3fn,
+                torch.int16,
+                torch.int32,
+                torch.int64,
+            ]
+            and cast
+        ):
             tensor = tensor.to(dtype=self.dtype)

         return tensor