From 119918cc0a28cccf9b93a575a2ae7483501a493b Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Sun, 21 Jul 2024 20:56:54 +0200 Subject: [PATCH] fix(server): fix fp8 weight loading --- server/text_generation_server/layers/fp8.py | 8 +++- .../text_generation_server/layers/marlin.py | 26 +++++----- .../text_generation_server/models/__init__.py | 48 ++++++++++++------- .../text_generation_server/utils/weights.py | 13 +++-- 4 files changed, 59 insertions(+), 36 deletions(-) diff --git a/server/text_generation_server/layers/fp8.py b/server/text_generation_server/layers/fp8.py index cdf16d6b..ed5114ce 100644 --- a/server/text_generation_server/layers/fp8.py +++ b/server/text_generation_server/layers/fp8.py @@ -115,8 +115,12 @@ class HybridFP8UnquantLoader(WeightsLoader): return UnquantizedWeight(w) def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int): - w = [weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes] - w = torch.cat(w, dim=dim) + # FIXME: Force to_device to false as fp8 weights do not support torch.cat on device yet + w = [ + weights.get_sharded(f"{p}.weight", dim=0, to_device=False) for p in prefixes + ] + # Concat then send to the device + w = torch.cat(w, dim=dim).to(weights.device) # FP8 branch if w.dtype == torch.float8_e4m3fn: diff --git a/server/text_generation_server/layers/marlin.py b/server/text_generation_server/layers/marlin.py index 40271c35..a28012da 100644 --- a/server/text_generation_server/layers/marlin.py +++ b/server/text_generation_server/layers/marlin.py @@ -504,7 +504,7 @@ class GPTQMarlinFP8Linear(nn.Module): def __init__( self, qweight: torch.Tensor, - scale: torch.Tensor, + scales: torch.Tensor, bias: Optional[torch.Tensor], ) -> None: super().__init__() @@ -514,8 +514,11 @@ class GPTQMarlinFP8Linear(nn.Module): log_once(logger.info, "GPU does not support FP8, using Marlin FP8 kernel") - scale = scale.to(torch.float16) - qweight, scales = repack_fp8_for_marlin(qweight, scale) + scales = scales.unsqueeze(0) + if scales.shape[1] == 1: + out_features, in_features = qweight.shape + scales = scales.repeat(1, out_features) + qweight, scales = repack_fp8_for_marlin(qweight, scales) in_features = qweight.shape[0] * MARLIN_TILE_SIZE out_features = scales.shape[1] @@ -530,13 +533,13 @@ class GPTQMarlinFP8Linear(nn.Module): ) @classmethod - def from_unquant(cls, weight, bias, _dtype): - qweight, scale = fp8_quantize(weight) - return cls(qweight=qweight, scale=scale, bias=bias) + def from_unquant(cls, weight, bias, dtype): + qweight, scales = fp8_quantize(weight) + return cls(qweight=qweight, scales=scales.to(dtype), bias=bias) @classmethod - def from_fp8(cls, weight, scale, _input_scale, bias, _dtype): - return cls(qweight=weight, scale=scale, bias=bias) + def from_fp8(cls, weight, scale, _input_scale, bias, dtype): + return cls(qweight=weight, scales=scale.to(dtype), bias=bias) def forward(self, A: torch.Tensor) -> torch.Tensor: assert marlin_kernels is not None @@ -591,7 +594,7 @@ def pack_fp8_as_int32(fp8_tensor: torch.Tensor) -> torch.Tensor: return packed -def repack_fp8_for_marlin(weight: torch.Tensor, scale: torch.Tensor): +def repack_fp8_for_marlin(weight: torch.Tensor, scales: torch.Tensor): """ Repack FP8 tensor for GPTQ-Marlin. """ @@ -608,7 +611,6 @@ def repack_fp8_for_marlin(weight: torch.Tensor, scale: torch.Tensor): qweight, perm, in_features, out_features, 8 ) - scales = scale.reshape(1, 1).repeat(1, out_features) scales = permute_scales(scales) return repacked, scales @@ -621,7 +623,7 @@ class MarlinWeight(Weight): Attributes: B (torch.Tensor): int4-quantized weights packed into int32. - s (torch.Tensor): float16 scales. + s (torch.Tensor): bfloat16/float16 scales. """ B: torch.Tensor @@ -629,7 +631,7 @@ class MarlinWeight(Weight): def __post_init__(self): assert self.B.dtype == torch.int32 - assert self.s.dtype == torch.float16 + assert self.s.dtype in [torch.float16, torch.bfloat16] def get_linear(self, bias: torch.Tensor): return MarlinLinear(weight=self, bias=bias) diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index a43cdfed..aa045ebf 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -306,20 +306,45 @@ def get_model( max_input_tokens: int, ) -> Model: global FLASH_ATTENTION + + config_dict, _ = PretrainedConfig.get_config_dict( + model_id, revision=revision, trust_remote_code=trust_remote_code + ) + model_type = config_dict.get("model_type", None) + + quantization_config = config_dict.get("quantization_config", None) + if quantization_config is not None and quantize is None: + method = quantization_config.get("quant_method", None) + if method in {"gptq", "awq", "exl2"}: + log_master(logger.info, f"Auto selecting quantization method {method}") + quantize = method + elif method == "fbgemm_fp8": + log_master(logger.info, "Auto selecting quantization method fp8") + quantize = "fp8" + else: + log_master(logger.warning, f"Unknown quantization method {method}") + if dtype is None: if quantize in ["awq", "exl2", "gptq", "marlin"]: # These quantizers only work with float16 params. dtype = torch.float16 elif quantize == "fp8": - from text_generation_server.layers.fp8 import FBGEMM_MM_AVAILABLE + from text_generation_server.layers.fp8 import FBGEMM_DYN_AVAILABLE - if FBGEMM_MM_AVAILABLE: + if FBGEMM_DYN_AVAILABLE: # fbgemm kernels are fp8xfp8->bf16 dtype = torch.bfloat16 else: - # Keep it as default for now and let - # every model resolve their own default dtype. - dtype = None + config_dtype = config_dict.get("torch_dtype", None) + # Only use the config dtype if its one of TGI's supported dtype + if config_dtype == "float16": + dtype = torch.float16 + elif config_dtype == "bfloat16": + dtype = torch.bfloat16 + else: + # Keep it as default for now and let + # every model resolve their own default dtype. + dtype = None elif dtype == "float16": dtype = torch.float16 elif dtype == "bfloat16": @@ -332,11 +357,6 @@ def get_model( else: set_speculate(0) - config_dict, _ = PretrainedConfig.get_config_dict( - model_id, revision=revision, trust_remote_code=trust_remote_code - ) - model_type = config_dict.get("model_type", None) - speculator = None if "medusa_num_heads" in config_dict: medusa_model_id = model_id @@ -451,14 +471,6 @@ def get_model( raise RuntimeError( f"Could not determine model type for {model_id} revision {revision}" ) - quantization_config = config_dict.get("quantization_config", None) - if quantization_config is not None and quantize is None: - method = quantization_config.get("quant_method", None) - if method in {"gptq", "awq", "exl2"}: - log_master(logger.info, f"Auto selecting quantization method {method}") - quantize = method - else: - log_master(logger.warning, f"Unknown quantization method {method}") if quantize == "exl2" and sharded: raise RuntimeError( diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index 66bb6051..108ced48 100644 --- a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -230,7 +230,9 @@ class Weights: tensor = tensor.to(device=self.device) return tensor - def get_partial_sharded(self, tensor_name: str, dim: int, to_dtype=True): + def get_partial_sharded( + self, tensor_name: str, dim: int, to_device=True, to_dtype=True + ): filename, tensor_name = self.get_filename(tensor_name) f = self._get_handle(filename) slice_ = f.get_slice(tensor_name) @@ -256,10 +258,11 @@ class Weights: and to_dtype ): tensor = tensor.to(dtype=self.dtype) - tensor = tensor.to(device=self.device) + if to_device: + tensor = tensor.to(device=self.device) return tensor - def get_sharded(self, tensor_name: str, dim: int, to_dtype=True): + def get_sharded(self, tensor_name: str, dim: int, to_device=True, to_dtype=True): filename, tensor_name = self.get_filename(tensor_name) f = self._get_handle(filename) slice_ = f.get_slice(tensor_name) @@ -268,7 +271,9 @@ class Weights: assert ( size % world_size == 0 ), f"The choosen size {size} is not compatible with sharding on {world_size} shards" - return self.get_partial_sharded(tensor_name, dim, to_dtype=to_dtype) + return self.get_partial_sharded( + tensor_name, dim, to_device=to_device, to_dtype=to_dtype + ) def get_packed_sharded( self,