fix(server): fix fp8 weight loading

This commit is contained in:
OlivierDehaene 2024-07-21 20:56:54 +02:00
parent 6aebf44f47
commit 119918cc0a
No known key found for this signature in database
GPG Key ID: BB104D67809DA93C
4 changed files with 59 additions and 36 deletions

View File

@ -115,8 +115,12 @@ class HybridFP8UnquantLoader(WeightsLoader):
return UnquantizedWeight(w) return UnquantizedWeight(w)
def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int): def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
w = [weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes] # FIXME: Force to_device to false as fp8 weights do not support torch.cat on device yet
w = torch.cat(w, dim=dim) w = [
weights.get_sharded(f"{p}.weight", dim=0, to_device=False) for p in prefixes
]
# Concat then send to the device
w = torch.cat(w, dim=dim).to(weights.device)
# FP8 branch # FP8 branch
if w.dtype == torch.float8_e4m3fn: if w.dtype == torch.float8_e4m3fn:

View File

@ -504,7 +504,7 @@ class GPTQMarlinFP8Linear(nn.Module):
def __init__( def __init__(
self, self,
qweight: torch.Tensor, qweight: torch.Tensor,
scale: torch.Tensor, scales: torch.Tensor,
bias: Optional[torch.Tensor], bias: Optional[torch.Tensor],
) -> None: ) -> None:
super().__init__() super().__init__()
@ -514,8 +514,11 @@ class GPTQMarlinFP8Linear(nn.Module):
log_once(logger.info, "GPU does not support FP8, using Marlin FP8 kernel") log_once(logger.info, "GPU does not support FP8, using Marlin FP8 kernel")
scale = scale.to(torch.float16) scales = scales.unsqueeze(0)
qweight, scales = repack_fp8_for_marlin(qweight, scale) if scales.shape[1] == 1:
out_features, in_features = qweight.shape
scales = scales.repeat(1, out_features)
qweight, scales = repack_fp8_for_marlin(qweight, scales)
in_features = qweight.shape[0] * MARLIN_TILE_SIZE in_features = qweight.shape[0] * MARLIN_TILE_SIZE
out_features = scales.shape[1] out_features = scales.shape[1]
@ -530,13 +533,13 @@ class GPTQMarlinFP8Linear(nn.Module):
) )
@classmethod @classmethod
def from_unquant(cls, weight, bias, _dtype): def from_unquant(cls, weight, bias, dtype):
qweight, scale = fp8_quantize(weight) qweight, scales = fp8_quantize(weight)
return cls(qweight=qweight, scale=scale, bias=bias) return cls(qweight=qweight, scales=scales.to(dtype), bias=bias)
@classmethod @classmethod
def from_fp8(cls, weight, scale, _input_scale, bias, _dtype): def from_fp8(cls, weight, scale, _input_scale, bias, dtype):
return cls(qweight=weight, scale=scale, bias=bias) return cls(qweight=weight, scales=scale.to(dtype), bias=bias)
def forward(self, A: torch.Tensor) -> torch.Tensor: def forward(self, A: torch.Tensor) -> torch.Tensor:
assert marlin_kernels is not None assert marlin_kernels is not None
@ -591,7 +594,7 @@ def pack_fp8_as_int32(fp8_tensor: torch.Tensor) -> torch.Tensor:
return packed return packed
def repack_fp8_for_marlin(weight: torch.Tensor, scale: torch.Tensor): def repack_fp8_for_marlin(weight: torch.Tensor, scales: torch.Tensor):
""" """
Repack FP8 tensor for GPTQ-Marlin. Repack FP8 tensor for GPTQ-Marlin.
""" """
@ -608,7 +611,6 @@ def repack_fp8_for_marlin(weight: torch.Tensor, scale: torch.Tensor):
qweight, perm, in_features, out_features, 8 qweight, perm, in_features, out_features, 8
) )
scales = scale.reshape(1, 1).repeat(1, out_features)
scales = permute_scales(scales) scales = permute_scales(scales)
return repacked, scales return repacked, scales
@ -621,7 +623,7 @@ class MarlinWeight(Weight):
Attributes: Attributes:
B (torch.Tensor): int4-quantized weights packed into int32. B (torch.Tensor): int4-quantized weights packed into int32.
s (torch.Tensor): float16 scales. s (torch.Tensor): bfloat16/float16 scales.
""" """
B: torch.Tensor B: torch.Tensor
@ -629,7 +631,7 @@ class MarlinWeight(Weight):
def __post_init__(self): def __post_init__(self):
assert self.B.dtype == torch.int32 assert self.B.dtype == torch.int32
assert self.s.dtype == torch.float16 assert self.s.dtype in [torch.float16, torch.bfloat16]
def get_linear(self, bias: torch.Tensor): def get_linear(self, bias: torch.Tensor):
return MarlinLinear(weight=self, bias=bias) return MarlinLinear(weight=self, bias=bias)

View File

@ -306,16 +306,41 @@ def get_model(
max_input_tokens: int, max_input_tokens: int,
) -> Model: ) -> Model:
global FLASH_ATTENTION global FLASH_ATTENTION
config_dict, _ = PretrainedConfig.get_config_dict(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
model_type = config_dict.get("model_type", None)
quantization_config = config_dict.get("quantization_config", None)
if quantization_config is not None and quantize is None:
method = quantization_config.get("quant_method", None)
if method in {"gptq", "awq", "exl2"}:
log_master(logger.info, f"Auto selecting quantization method {method}")
quantize = method
elif method == "fbgemm_fp8":
log_master(logger.info, "Auto selecting quantization method fp8")
quantize = "fp8"
else:
log_master(logger.warning, f"Unknown quantization method {method}")
if dtype is None: if dtype is None:
if quantize in ["awq", "exl2", "gptq", "marlin"]: if quantize in ["awq", "exl2", "gptq", "marlin"]:
# These quantizers only work with float16 params. # These quantizers only work with float16 params.
dtype = torch.float16 dtype = torch.float16
elif quantize == "fp8": elif quantize == "fp8":
from text_generation_server.layers.fp8 import FBGEMM_MM_AVAILABLE from text_generation_server.layers.fp8 import FBGEMM_DYN_AVAILABLE
if FBGEMM_MM_AVAILABLE: if FBGEMM_DYN_AVAILABLE:
# fbgemm kernels are fp8xfp8->bf16 # fbgemm kernels are fp8xfp8->bf16
dtype = torch.bfloat16 dtype = torch.bfloat16
else:
config_dtype = config_dict.get("torch_dtype", None)
# Only use the config dtype if its one of TGI's supported dtype
if config_dtype == "float16":
dtype = torch.float16
elif config_dtype == "bfloat16":
dtype = torch.bfloat16
else: else:
# Keep it as default for now and let # Keep it as default for now and let
# every model resolve their own default dtype. # every model resolve their own default dtype.
@ -332,11 +357,6 @@ def get_model(
else: else:
set_speculate(0) set_speculate(0)
config_dict, _ = PretrainedConfig.get_config_dict(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
model_type = config_dict.get("model_type", None)
speculator = None speculator = None
if "medusa_num_heads" in config_dict: if "medusa_num_heads" in config_dict:
medusa_model_id = model_id medusa_model_id = model_id
@ -451,14 +471,6 @@ def get_model(
raise RuntimeError( raise RuntimeError(
f"Could not determine model type for {model_id} revision {revision}" f"Could not determine model type for {model_id} revision {revision}"
) )
quantization_config = config_dict.get("quantization_config", None)
if quantization_config is not None and quantize is None:
method = quantization_config.get("quant_method", None)
if method in {"gptq", "awq", "exl2"}:
log_master(logger.info, f"Auto selecting quantization method {method}")
quantize = method
else:
log_master(logger.warning, f"Unknown quantization method {method}")
if quantize == "exl2" and sharded: if quantize == "exl2" and sharded:
raise RuntimeError( raise RuntimeError(

View File

@ -230,7 +230,9 @@ class Weights:
tensor = tensor.to(device=self.device) tensor = tensor.to(device=self.device)
return tensor return tensor
def get_partial_sharded(self, tensor_name: str, dim: int, to_dtype=True): def get_partial_sharded(
self, tensor_name: str, dim: int, to_device=True, to_dtype=True
):
filename, tensor_name = self.get_filename(tensor_name) filename, tensor_name = self.get_filename(tensor_name)
f = self._get_handle(filename) f = self._get_handle(filename)
slice_ = f.get_slice(tensor_name) slice_ = f.get_slice(tensor_name)
@ -256,10 +258,11 @@ class Weights:
and to_dtype and to_dtype
): ):
tensor = tensor.to(dtype=self.dtype) tensor = tensor.to(dtype=self.dtype)
if to_device:
tensor = tensor.to(device=self.device) tensor = tensor.to(device=self.device)
return tensor return tensor
def get_sharded(self, tensor_name: str, dim: int, to_dtype=True): def get_sharded(self, tensor_name: str, dim: int, to_device=True, to_dtype=True):
filename, tensor_name = self.get_filename(tensor_name) filename, tensor_name = self.get_filename(tensor_name)
f = self._get_handle(filename) f = self._get_handle(filename)
slice_ = f.get_slice(tensor_name) slice_ = f.get_slice(tensor_name)
@ -268,7 +271,9 @@ class Weights:
assert ( assert (
size % world_size == 0 size % world_size == 0
), f"The choosen size {size} is not compatible with sharding on {world_size} shards" ), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
return self.get_partial_sharded(tensor_name, dim, to_dtype=to_dtype) return self.get_partial_sharded(
tensor_name, dim, to_device=to_device, to_dtype=to_dtype
)
def get_packed_sharded( def get_packed_sharded(
self, self,