From ee4174b6c726fcfc520786f58d3723fe0a4e89ca Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Thu, 18 Jul 2024 15:05:04 +0200
Subject: [PATCH] allow loading fp8 weights directly

---
 server/text_generation_server/layers/fp8.py  | 41 +++++++++-----
 .../text_generation_server/layers/marlin.py  | 13 ++++-
 .../text_generation_server/utils/weights.py  | 55 +++++++++++++++++--
 3 files changed, 88 insertions(+), 21 deletions(-)

diff --git a/server/text_generation_server/layers/fp8.py b/server/text_generation_server/layers/fp8.py
index f7512beb..b4e7eb78 100644
--- a/server/text_generation_server/layers/fp8.py
+++ b/server/text_generation_server/layers/fp8.py
@@ -1,6 +1,7 @@
+import torch
+
 from dataclasses import dataclass
 
-import torch
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.utils.weights import Weight
 
@@ -64,27 +65,41 @@ class Fp8Weight(Weight):
 
 class Fp8Linear(torch.nn.Module):
     def __init__(
         self,
-        weight,
+        qweight,
+        scale,
+        scale_upper_bound,
         bias,
+        dtype,
     ) -> None:
         super().__init__()
-        self.dtype = weight.dtype
-        self.qweight, self.scale = fp8_quantize(weight)
+        self.dtype = dtype
+        self.qweight = qweight
+        self.scale = scale
+        self.scale_upper_bound = scale_upper_bound
         self.bias = bias if bias is not None else None
 
+    @classmethod
+    def from_unquant(cls, weight, bias, dtype):
+        qweight, scale = fp8_quantize(weight)
+        return cls(
+            qweight=qweight, scale=scale, scale_upper_bound=None, bias=bias, dtype=dtype
+        )
+
+    @classmethod
+    def from_fp8(cls, weight, bias, dtype):
+        return cls(
+            qweight=weight.weight,
+            scale=weight.weight_scale,
+            scale_upper_bound=weight.input_scale,
+            bias=bias,
+            dtype=dtype,
+        )
+
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         if HAS_FBGEMM:
-            global default_activation_scale_upper_bound
-
-            device = input.device
-            if default_activation_scale_upper_bound.device != device:
-                default_activation_scale_upper_bound = (
-                    default_activation_scale_upper_bound.to(device)
-                )
-
             qinput, scale = fp8_quantize(
-                input, scale_upper_bound=default_activation_scale_upper_bound
+                input, scale_upper_bound=self.scale_upper_bound
             )
 
             y = torch.ops.fbgemm.f8f8bf16_rowwise(
diff --git a/server/text_generation_server/layers/marlin.py b/server/text_generation_server/layers/marlin.py
index a913ff57..542d9a35 100644
--- a/server/text_generation_server/layers/marlin.py
+++ b/server/text_generation_server/layers/marlin.py
@@ -503,7 +503,8 @@ class GPTQMarlinFP8Linear(nn.Module):
     def __init__(
         self,
-        weight: torch.Tensor,
+        qweight: torch.Tensor,
+        scale: torch.Tensor,
         bias: Optional[torch.Tensor],
     ) -> None:
         super().__init__()
 
@@ -513,7 +514,6 @@ class GPTQMarlinFP8Linear(nn.Module):
 
         log_once(logger.info, "GPU does not support FP8, using Marlin FP8 kernel")
 
-        qweight, scale = fp8_quantize(weight)
         scale = scale.to(torch.float16)
         qweight, scales = repack_fp8_for_marlin(qweight, scale)
 
@@ -529,6 +529,15 @@ class GPTQMarlinFP8Linear(nn.Module):
             out_features // 64 * 16, dtype=torch.int, device=qweight.device
         )
 
+    @classmethod
+    def from_unquant(cls, weight, bias):
+        qweight, scale = fp8_quantize(weight)
+        return cls(qweight=qweight, scale=scale, bias=bias)
+
+    @classmethod
+    def from_fp8(cls, weight, bias):
+        return cls(qweight=weight.weight, scale=weight.weight_scale, bias=bias)
+
     def forward(self, A: torch.Tensor) -> torch.Tensor:
         assert marlin_kernels is not None
 
diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py
index 91592df0..9c2d6cfe 100644
--- a/server/text_generation_server/utils/weights.py
+++ b/server/text_generation_server/utils/weights.py
@@ -1,3 +1,5 @@
+import torch
+
 from abc import ABC, abstractmethod
 from contextlib import contextmanager
 from dataclasses import dataclass
@@ -5,9 +7,9 @@ from enum import Enum, auto
 from pathlib import Path
 from typing import Dict, List, Optional, Union
 
-import torch
-from safetensors import safe_open
 from text_generation_server.utils.import_utils import SYSTEM
+from safetensors import safe_open
+from dataclasses import dataclass
 
 
 class WeightsLoader(ABC):
@@ -127,15 +129,45 @@ class DefaultWeightsLoader(WeightsLoader):
             ),
         )
 
+        w = weights.get_packed_sharded(
+            f"{prefix}.weight", dim=0, block_sizes=block_sizes
+        )
+        # FP8 branch
+        if w.dtype == torch.float8_e4m3fn:
+            scale = weights.get_packed_sharded(
+                f"{prefix}.weight_scale", dim=0, block_sizes=block_sizes
+            )
+            input_scale = weights.get_tensor(f"{prefix}.input_scale")
+            return FP8Weight(weight=w, weight_scale=scale, input_scale=input_scale)
+        return w
+
     def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
         w = [weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
         return self.weight_class(torch.cat(w, dim=dim))
 
+        w = [weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
+        w = torch.cat(w, dim=dim)
+        # FP8 branch
+        if w.dtype == torch.float8_e4m3fn:
+            scale = [weights.get_sharded(f"{p}.weight_scale", dim=0) for p in prefixes]
+            scale = torch.cat(scale, dim=0)
+            input_scale = weights.get_tensor(f"{prefixes[0]}.input_scale")
+            return FP8Weight(weight=w, weight_scale=scale, input_scale=input_scale)
+        return w
+
     def get_weights_row(self, weights: "Weights", prefix: str):
         return self.weight_class(
             weights.get_sharded(f"{prefix}.weight", dim=1),
         )
 
+        w = weights.get_sharded(f"{prefix}.weight", dim=1)
+        # FP8 branch
+        if w.dtype == torch.float8_e4m3fn:
+            scale = weights.get_sharded(f"{prefix}.weight_scale", dim=0)
+            input_scale = weights.get_tensor(f"{prefix}.input_scale")
+            return FP8Weight(weight=w, weight_scale=scale, input_scale=input_scale)
+        return w
+
 
 class Weights:
     def __init__(
@@ -214,8 +246,13 @@ class Weights:
         tensor = f.get_tensor(tensor_name)
         # Special case for gptq which shouldn't convert
         # u4 which are disguised as int32. Exl2 uses int16
-        # as well.
-        if tensor.dtype not in [torch.int16, torch.int32, torch.int64]:
+        # as well. FP8 uses torch.float8_e4m3fn
+        if tensor.dtype not in [
+            torch.float8_e4m3fn,
+            torch.int16,
+            torch.int32,
+            torch.int64,
+        ]:
             tensor = tensor.to(dtype=self.dtype)
         if to_device:
             tensor = tensor.to(device=self.device)
@@ -241,7 +278,8 @@ class Weights:
             raise NotImplementedError("Let's make that generic when needed")
         # Special case for gptq which shouldn't convert
         # u4 which are disguised as int32. exl2 uses int16.
-        if tensor.dtype not in (torch.int16, torch.int32):
+        # FP8 uses torch.float8_e4m3fn.
+        if tensor.dtype not in (torch.float8_e4m3fn, torch.int16, torch.int32):
             tensor = tensor.to(dtype=self.dtype)
         tensor = tensor.to(device=self.device)
         return tensor
@@ -304,7 +342,12 @@ class Weights:
         tensor = tensor.to(device=self.device)
 
         # Avoid casting quantizer dtypes.
-        if tensor.dtype not in [torch.int16, torch.int32, torch.int64]:
+        if tensor.dtype not in [
+            torch.float8_e4m3fn,
+            torch.int16,
+            torch.int32,
+            torch.int64,
+        ]:
             tensor = tensor.to(dtype=self.dtype)
 
         return tensor
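
A minimal usage sketch for reviewers, not part of the patch: it illustrates how the two constructors added to `Fp8Linear` are meant to be selected, depending on whether the loader returned a plain tensor or an FP8 weight container carrying `weight`, `weight_scale`, and `input_scale`. The helper name `build_fp8_linear` is hypothetical.

```python
# Hypothetical dispatch helper -- shows the intended split between
# Fp8Linear.from_unquant (quantize at load time) and Fp8Linear.from_fp8
# (checkpoint already stores fp8 weights plus their scales).
from text_generation_server.layers.fp8 import Fp8Linear


def build_fp8_linear(weight, bias, dtype):
    if hasattr(weight, "weight_scale"):
        # FP8 branch of the loaders: a container exposing
        # .weight / .weight_scale / .input_scale.
        return Fp8Linear.from_fp8(weight, bias, dtype)
    # Plain higher-precision torch.Tensor: quantize on the fly.
    return Fp8Linear.from_unquant(weight, bias, dtype)
```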
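Also for illustration, a sketch of the safetensors layout the new FP8 branches look for: a `*.weight` tensor already stored as `torch.float8_e4m3fn`, with `*.weight_scale` and `*.input_scale` companions under the same prefix. The tensor names, shapes, and scale values below are made up, and the snippet assumes torch and safetensors versions with float8 support.

```python
# Illustrative only: write a tiny checkpoint in the layout the FP8 branches expect.
import torch
from safetensors.torch import save_file

prefix = "model.layers.0.self_attn.q_proj"  # hypothetical layer prefix
w = torch.randn(128, 256, dtype=torch.float16)

save_file(
    {
        f"{prefix}.weight": w.to(torch.float8_e4m3fn),  # pre-quantized weight
        f"{prefix}.weight_scale": torch.ones(128),      # per-row weight scale (made up)
        f"{prefix}.input_scale": torch.tensor(1.0),     # activation scale upper bound (made up)
    },
    "fp8-demo.safetensors",
)
```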