diff --git a/server/text_generation_server/layers/fp8.py b/server/text_generation_server/layers/fp8.py
index 0bf3eeeb..fe083b68 100644
--- a/server/text_generation_server/layers/fp8.py
+++ b/server/text_generation_server/layers/fp8.py
@@ -57,14 +57,15 @@ def fp8_quantize(weight, scale_upper_bound=None, qdtype=torch.float8_e4m3fn):
 @dataclass
 class Fp8Weight:
     weight: torch.Tensor
+    dtype: torch.dtype
     weight_scale: Optional[torch.Tensor] = None
     input_scale: Optional[torch.Tensor] = None
 
     def get_linear(self, bias: torch.Tensor):
         if self.weight_scale is None:
-            return get_fp8_linear().from_unquant(self.weight, bias)
+            return get_fp8_linear().from_unquant(self.weight, bias, self.dtype)
         return get_fp8_linear().from_fp8(
-            self.weight, self.weight_scale, self.input_scale, bias, bias.dtype
+            self.weight, self.weight_scale, self.input_scale, bias, self.dtype
         )


@@ -110,7 +111,7 @@ class Fp8Linear(torch.nn.Module):

             y = torch.ops.fbgemm.f8f8bf16_rowwise(
                 qinput,
-                self.weight,
+                self.qweight,
                 scale,
                 self.scale,
                 use_fast_accum=True,
diff --git a/server/text_generation_server/layers/marlin.py b/server/text_generation_server/layers/marlin.py
index 542d9a35..40271c35 100644
--- a/server/text_generation_server/layers/marlin.py
+++ b/server/text_generation_server/layers/marlin.py
@@ -530,13 +530,13 @@ class GPTQMarlinFP8Linear(nn.Module):
         )
 
     @classmethod
-    def from_unquant(cls, weight, bias):
+    def from_unquant(cls, weight, bias, _dtype):
         qweight, scale = fp8_quantize(weight)
         return cls(qweight=qweight, scale=scale, bias=bias)
 
     @classmethod
-    def from_fp8(cls, weight, bias):
-        return cls(qweight=weight.weight, scale=weight.weight_scale, bias=bias)
+    def from_fp8(cls, weight, scale, _input_scale, bias, _dtype):
+        return cls(qweight=weight, scale=scale, bias=bias)
 
     def forward(self, A: torch.Tensor) -> torch.Tensor:
         assert marlin_kernels is not None
diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index 690a8887..725aa544 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -311,6 +311,9 @@ def get_model(
     if quantize in ["awq", "exl2", "gptq", "marlin"]:
         # These quantizers only work with float16 params.
         dtype = torch.float16
+    elif quantize == "fp8":
+        # gemm kernels are fp8xfp8->bf16
+        dtype = torch.bfloat16
     else:
         # Keep it as default for now and let
         # every model resolve their own default dtype.
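Side note on the signature change above: both `from_unquant` and `from_fp8` now receive the target dtype explicitly, so the output dtype no longer has to be read from `bias.dtype` (the bias is often `None` for LLM linear layers). Below is a minimal, self-contained sketch of the pattern using stub classes; these are illustrative stand-ins, not the real TGI backends.

```python
# Illustrative sketch only (stub classes, not the real TGI backends): every
# fp8 linear backend now accepts the same arguments, and the output dtype
# travels with the weight instead of being read from `bias`.
from dataclasses import dataclass
from typing import Optional

import torch


class StubFp8Linear:
    @classmethod
    def from_unquant(cls, weight, bias, dtype):
        # Would quantize `weight` to fp8 here; `dtype` is the output dtype.
        return cls()

    @classmethod
    def from_fp8(cls, weight, scale, input_scale, bias, dtype):
        # Pre-quantized checkpoint: `weight` is already float8_e4m3fn.
        return cls()


@dataclass
class StubFp8Weight:
    weight: torch.Tensor
    dtype: torch.dtype
    weight_scale: Optional[torch.Tensor] = None
    input_scale: Optional[torch.Tensor] = None

    def get_linear(self, bias: Optional[torch.Tensor]):
        if self.weight_scale is None:
            return StubFp8Linear.from_unquant(self.weight, bias, self.dtype)
        return StubFp8Linear.from_fp8(
            self.weight, self.weight_scale, self.input_scale, bias, self.dtype
        )


# Works even without a bias, since the dtype no longer comes from `bias.dtype`.
StubFp8Weight(weight=torch.empty(8, 8), dtype=torch.bfloat16).get_linear(bias=None)
```

As the marlin.py hunk shows, a backend that does not need the extra arguments simply ignores them (`_input_scale`, `_dtype`), which keeps the call sites uniform across backends.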
diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py
index 491f92ea..108ba6e7 100644
--- a/server/text_generation_server/utils/weights.py
+++ b/server/text_generation_server/utils/weights.py
@@ -86,6 +86,7 @@ class Weight(ABC):
 @dataclass
 class UnquantizedWeight:
     weight: torch.Tensor
+    dtype: torch.dtype
 
     def get_linear(self, bias: torch.Tensor):
         from text_generation_server.layers.linear import FastLinear, FastLinearROCm
@@ -137,14 +138,19 @@ class DefaultWeightsLoader(WeightsLoader):

             # FP8 branch
             scale = weights.get_packed_sharded(
-                f"{prefix}.weight_scale", dim=0, block_sizes=block_sizes
+                f"{prefix}.weight_scale", dim=0, block_sizes=block_sizes, cast=False
+            )
+            input_scale = weights.get_tensor(f"{prefix}.input_scale", cast=False)
+            return Fp8Weight(
+                weight=w,
+                weight_scale=scale,
+                input_scale=input_scale,
+                dtype=weights.dtype,
             )
-            input_scale = weights.get_tensor(f"{prefix}.input_scale")
-            return Fp8Weight(weight=w, weight_scale=scale, input_scale=input_scale)
 
         if self.weight_class is None:
-            return UnquantizedWeight(w)
-        return self.weight_class(w)
+            return UnquantizedWeight(w, dtype=weights.dtype)
+        return self.weight_class(w, dtype=weights.dtype)
 
     def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
         w = [weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
@@ -160,14 +166,22 @@ class DefaultWeightsLoader(WeightsLoader):
                     f"Deserialized quantised fp8 weights but weight class is {self.weight_class}"
                 )

-            scale = [weights.get_sharded(f"{p}.weight_scale", dim=0) for p in prefixes]
+            scale = [
+                weights.get_sharded(f"{p}.weight_scale", dim=0, cast=False)
+                for p in prefixes
+            ]
             scale = torch.cat(scale, dim=0)
-            input_scale = weights.get_tensor(f"{prefixes[0]}.input_scale")
-            return Fp8Weight(weight=w, weight_scale=scale, input_scale=input_scale)
+            input_scale = weights.get_tensor(f"{prefixes[0]}.input_scale", cast=False)
+            return Fp8Weight(
+                weight=w,
+                weight_scale=scale,
+                input_scale=input_scale,
+                dtype=weights.dtype,
+            )
 
         if self.weight_class is None:
-            return UnquantizedWeight(w)
-        return self.weight_class(w)
+            return UnquantizedWeight(w, dtype=weights.dtype)
+        return self.weight_class(w, dtype=weights.dtype)
 
     def get_weights_row(self, weights: "Weights", prefix: str):
         w = weights.get_sharded(f"{prefix}.weight", dim=1)
@@ -181,13 +195,18 @@ class DefaultWeightsLoader(WeightsLoader):
                     f"Deserialized quantised fp8 weights but weight class is {self.weight_class}"
                 )

-            scale = weights.get_sharded(f"{prefix}.weight_scale", dim=0)
-            input_scale = weights.get_tensor(f"{prefix}.input_scale")
-            return Fp8Weight(weight=w, weight_scale=scale, input_scale=input_scale)
+            scale = weights.get_sharded(f"{prefix}.weight_scale", dim=0, cast=False)
+            input_scale = weights.get_tensor(f"{prefix}.input_scale", cast=False)
+            return Fp8Weight(
+                weight=w,
+                weight_scale=scale,
+                input_scale=input_scale,
+                dtype=weights.dtype,
+            )
 
         if self.weight_class is None:
-            return UnquantizedWeight(w)
-        return self.weight_class(w)
+            return UnquantizedWeight(w, dtype=weights.dtype)
+        return self.weight_class(w, dtype=weights.dtype)


 class Weights:
@@ -261,25 +280,29 @@ class Weights:
     def get_shape(self, tensor_name: str):
         return self._get_slice(tensor_name).get_shape()
 
-    def get_tensor(self, tensor_name: str, to_device=True):
+    def get_tensor(self, tensor_name: str, to_device=True, cast=True):
         filename, tensor_name = self.get_filename(tensor_name)
         f = self._get_handle(filename)
         tensor = f.get_tensor(tensor_name)
         # Special case for gptq which shouldn't convert
         # u4 which are disguised as int32. Exl2 uses int16
         # as well. FP8 uses torch.float8_e4m3fn
-        if tensor.dtype not in [
-            torch.float8_e4m3fn,
-            torch.int16,
-            torch.int32,
-            torch.int64,
-        ]:
+        if (
+            tensor.dtype
+            not in [
+                torch.float8_e4m3fn,
+                torch.int16,
+                torch.int32,
+                torch.int64,
+            ]
+            and cast
+        ):
             tensor = tensor.to(dtype=self.dtype)
         if to_device:
             tensor = tensor.to(device=self.device)
         return tensor
 
-    def get_partial_sharded(self, tensor_name: str, dim: int):
+    def get_partial_sharded(self, tensor_name: str, dim: int, cast=True):
         filename, tensor_name = self.get_filename(tensor_name)
         f = self._get_handle(filename)
         slice_ = f.get_slice(tensor_name)
@@ -300,12 +323,12 @@ class Weights:
         # Special case for gptq which shouldn't convert
         # u4 which are disguised as int32. exl2 uses int16.
         # FP8 uses torch.float8_e4m3fn.
-        if tensor.dtype not in (torch.float8_e4m3fn, torch.int16, torch.int32):
+        if tensor.dtype not in (torch.float8_e4m3fn, torch.int16, torch.int32) and cast:
             tensor = tensor.to(dtype=self.dtype)
         tensor = tensor.to(device=self.device)
         return tensor
 
-    def get_sharded(self, tensor_name: str, dim: int):
+    def get_sharded(self, tensor_name: str, dim: int, cast=True):
         filename, tensor_name = self.get_filename(tensor_name)
         f = self._get_handle(filename)
         slice_ = f.get_slice(tensor_name)
@@ -314,10 +337,10 @@ class Weights:
         assert (
             size % world_size == 0
         ), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
-        return self.get_partial_sharded(tensor_name, dim)
+        return self.get_partial_sharded(tensor_name, dim, cast=cast)
 
     def get_packed_sharded(
-        self, tensor_name: str, dim: int, block_sizes: Union[int, List[int]]
+        self, tensor_name: str, dim: int, block_sizes: Union[int, List[int]], cast=True
     ) -> torch.Tensor:
         """
         Get a shard from a tensor that packs multiple tensors.
@@ -363,12 +386,16 @@ class Weights:
             tensor = tensor.to(device=self.device)
 
         # Avoid casting quantizer dtypes.
-        if tensor.dtype not in [
-            torch.float8_e4m3fn,
-            torch.int16,
-            torch.int32,
-            torch.int64,
-        ]:
+        if (
+            tensor.dtype
+            not in [
+                torch.float8_e4m3fn,
+                torch.int16,
+                torch.int32,
+                torch.int64,
+            ]
+            and cast
+        ):
             tensor = tensor.to(dtype=self.dtype)
 
         return tensor
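For context on the `cast` flag threaded through `get_tensor`, `get_sharded`, `get_partial_sharded`, and `get_packed_sharded`: fp8 scale tensors (`weight_scale`, `input_scale`) are stored in their own dtype (typically float32) and should not be converted to the model dtype on load, so the scale lookups pass `cast=False`. A standalone sketch of the guard follows; `load_tensor` is a hypothetical helper used only for illustration, not the real `Weights` class.

```python
# Hypothetical stand-in for the cast guard added to Weights.get_tensor and
# friends (illustrative only): quantizer dtypes are never converted, and
# cast=False lets callers keep scale tensors in their stored dtype instead
# of the model dtype.
import torch


def load_tensor(
    tensor: torch.Tensor, model_dtype: torch.dtype, cast: bool = True
) -> torch.Tensor:
    quantizer_dtypes = (torch.float8_e4m3fn, torch.int16, torch.int32, torch.int64)
    if tensor.dtype not in quantizer_dtypes and cast:
        tensor = tensor.to(dtype=model_dtype)
    return tensor


qweight = torch.empty(4, 4, dtype=torch.float8_e4m3fn)    # quantized payload
weight_scale = torch.tensor([0.02], dtype=torch.float32)  # fp8 scale

assert load_tensor(qweight, torch.bfloat16).dtype == torch.float8_e4m3fn
assert load_tensor(weight_scale, torch.bfloat16).dtype == torch.bfloat16             # default cast
assert load_tensor(weight_scale, torch.bfloat16, cast=False).dtype == torch.float32  # new opt-out
```

Quantized payloads (float8/int dtypes) were already exempt from the cast; `cast=False` extends the same treatment to the scales that accompany them.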