allow loading fp8 weights directly

OlivierDehaene 2024-07-18 15:05:04 +02:00
parent 27084bbfd3
commit ee4174b6c7
3 changed files with 88 additions and 21 deletions
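
The change adds a second loading path for FP8 linear layers: besides quantizing bf16/fp16 weights to fp8 at load time (the existing behaviour), the server can now consume checkpoints that already store float8_e4m3fn weights together with their scales. As a rough, stand-alone sketch of the checkpoint convention the diffs below read; the file path and layer prefix are placeholders and the helper is not part of this commit, only the ".weight_scale" / ".input_scale" tensor names come from the diff:

import torch
from safetensors import safe_open

def describe_linear(path: str, prefix: str) -> str:
    # Inspect one linear layer in a safetensors shard and report which of the
    # two loading paths introduced by this commit would apply to it.
    with safe_open(path, framework="pt") as f:
        w = f.get_tensor(f"{prefix}.weight")
        if w.dtype == torch.float8_e4m3fn:
            # Pre-quantized FP8 checkpoint: scales are stored alongside the weight.
            scale = f.get_tensor(f"{prefix}.weight_scale")
            input_scale = f.get_tensor(f"{prefix}.input_scale")
            return f"fp8: weight {tuple(w.shape)}, weight_scale {tuple(scale.shape)}, input_scale {tuple(input_scale.shape)}"
        # Half-precision checkpoint: weights would be quantized at load time.
        return f"unquantized ({w.dtype}): weight {tuple(w.shape)}"

# Example call (path and prefix are illustrative):
# print(describe_linear("model.safetensors", "model.layers.0.self_attn.q_proj"))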


@@ -1,6 +1,7 @@
+import torch
+
 from dataclasses import dataclass
-import torch
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.utils.weights import Weight
@@ -64,27 +65,41 @@ class Fp8Weight(Weight):
 class Fp8Linear(torch.nn.Module):
     def __init__(
         self,
-        weight,
+        qweight,
+        scale,
+        scale_upper_bound,
         bias,
+        dtype,
     ) -> None:
         super().__init__()
-        self.dtype = weight.dtype
-        self.qweight, self.scale = fp8_quantize(weight)
+        self.dtype = dtype
+        self.qweight = qweight
+        self.scale = scale
+        self.scale_upper_bound = scale_upper_bound
         self.bias = bias if bias is not None else None

-    def forward(self, input: torch.Tensor) -> torch.Tensor:
-        if HAS_FBGEMM:
-            global default_activation_scale_upper_bound
-
-            device = input.device
-            if default_activation_scale_upper_bound.device != device:
-                default_activation_scale_upper_bound = (
-                    default_activation_scale_upper_bound.to(device)
-                )
+    @classmethod
+    def from_unquant(cls, weight, bias, dtype):
+        qweight, scale = fp8_quantize(weight)
+        return cls(
+            qweight=qweight, scale=scale, scale_upper_bound=None, bias=bias, dtype=dtype
+        )
+
+    @classmethod
+    def from_fp8(cls, weight, bias, dtype):
+        return cls(
+            qweight=weight.weight,
+            scale=weight.weight_scale,
+            scale_upper_bound=weight.input_scale,
+            bias=bias,
+            dtype=dtype,
+        )

+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        if HAS_FBGEMM:
             qinput, scale = fp8_quantize(
-                input, scale_upper_bound=default_activation_scale_upper_bound
+                input, scale_upper_bound=self.scale_upper_bound
             )
             y = torch.ops.fbgemm.f8f8bf16_rowwise(
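
With these classmethods, Fp8Linear can be built either from a half-precision weight, quantized on the fly as before, or directly from a pre-quantized FP8 weight whose input_scale becomes the activation scale upper bound. A hypothetical wiring sketch; the build_fp8_linear helper and the shape of `loaded` are assumptions for illustration, not code from this commit:

import torch

def build_fp8_linear(loaded, bias, dtype=torch.bfloat16):
    # Illustrative only: `loaded` is whatever the weights loader returned for
    # this layer, either a plain half-precision tensor or an FP8Weight-style
    # container exposing .weight / .weight_scale / .input_scale (see the loader
    # changes further down).
    if isinstance(loaded, torch.Tensor):
        # Plain bf16/fp16 checkpoint: quantize at load time, no static input scale.
        return Fp8Linear.from_unquant(loaded, bias, dtype)
    # Pre-quantized checkpoint: reuse the shipped qweight and scales.
    return Fp8Linear.from_fp8(loaded, bias, dtype)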


@@ -503,7 +503,8 @@ class GPTQMarlinFP8Linear(nn.Module):
     def __init__(
         self,
-        weight: torch.Tensor,
+        qweight: torch.Tensor,
+        scale: torch.Tensor,
         bias: Optional[torch.Tensor],
     ) -> None:
         super().__init__()
@@ -513,7 +514,6 @@ class GPTQMarlinFP8Linear(nn.Module):
         log_once(logger.info, "GPU does not support FP8, using Marlin FP8 kernel")

-        qweight, scale = fp8_quantize(weight)
         scale = scale.to(torch.float16)
         qweight, scales = repack_fp8_for_marlin(qweight, scale)
@@ -529,6 +529,15 @@ class GPTQMarlinFP8Linear(nn.Module):
             out_features // 64 * 16, dtype=torch.int, device=qweight.device
         )

+    @classmethod
+    def from_unquant(cls, weight, bias):
+        qweight, scale = fp8_quantize(weight)
+        return cls(qweight=qweight, scale=scale, bias=bias)
+
+    @classmethod
+    def from_fp8(cls, weight, bias):
+        return cls(qweight=weight.weight, scale=weight.weight_scale, bias=bias)
+
     def forward(self, A: torch.Tensor) -> torch.Tensor:
         assert marlin_kernels is not None
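
The Marlin fallback only needs the quantized weight and its scale (cast to float16 before repacking); there is no activation scale in this path. Conceptually, per-tensor FP8 dequantization is a cast plus a multiply, as in this stand-alone illustration with arbitrary shapes and values (not code from the kernel itself):

import torch

# Pretend-quantized weight: raw float8_e4m3fn codes plus a single scale.
qweight = torch.randn(16, 32, dtype=torch.float16).to(torch.float8_e4m3fn)
scale = torch.tensor(0.05, dtype=torch.float16)

# What the kernel conceptually reconstructs on the fly: codes * scale in fp16.
w_half = qweight.to(torch.float16) * scale
print(w_half.dtype, w_half.shape)  # torch.float16 torch.Size([16, 32])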


@@ -1,3 +1,5 @@
+import torch
+
 from abc import ABC, abstractmethod
 from contextlib import contextmanager
 from dataclasses import dataclass
@@ -5,9 +7,9 @@ from enum import Enum, auto
 from pathlib import Path
 from typing import Dict, List, Optional, Union

-import torch
+from safetensors import safe_open
+
 from text_generation_server.utils.import_utils import SYSTEM
-from safetensors import safe_open
-from dataclasses import dataclass


 class WeightsLoader(ABC):
@@ -127,15 +129,45 @@ class DefaultWeightsLoader(WeightsLoader):
-        return self.weight_class(
-            weights.get_packed_sharded(
-                f"{prefix}.weight", dim=0, block_sizes=block_sizes
-            ),
-        )
+        w = weights.get_packed_sharded(
+            f"{prefix}.weight", dim=0, block_sizes=block_sizes
+        )
+        # FP8 branch
+        if w.dtype == torch.float8_e4m3fn:
+            scale = weights.get_packed_sharded(
+                f"{prefix}.weight_scale", dim=0, block_sizes=block_sizes
+            )
+            input_scale = weights.get_tensor(f"{prefix}.input_scale")
+            return FP8Weight(weight=w, weight_scale=scale, input_scale=input_scale)
+        return w
     def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
-        w = [weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
-        return self.weight_class(torch.cat(w, dim=dim))
+        w = [weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
+        w = torch.cat(w, dim=dim)
+        # FP8 branch
+        if w.dtype == torch.float8_e4m3fn:
+            scale = [weights.get_sharded(f"{p}.weight_scale", dim=0) for p in prefixes]
+            scale = torch.cat(scale, dim=0)
+            input_scale = weights.get_tensor(f"{prefixes[0]}.input_scale")
+            return FP8Weight(weight=w, weight_scale=scale, input_scale=input_scale)
+        return w
     def get_weights_row(self, weights: "Weights", prefix: str):
-        return self.weight_class(
-            weights.get_sharded(f"{prefix}.weight", dim=1),
-        )
+        w = weights.get_sharded(f"{prefix}.weight", dim=1)
+        # FP8 branch
+        if w.dtype == torch.float8_e4m3fn:
+            scale = weights.get_sharded(f"{prefix}.weight_scale", dim=0)
+            input_scale = weights.get_tensor(f"{prefix}.input_scale")
+            return FP8Weight(weight=w, weight_scale=scale, input_scale=input_scale)
+        return w


 class Weights:
     def __init__(
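
The FP8 branches above return an FP8Weight container holding the raw fp8 tensor and its scales; its definition is not visible in this excerpt. A minimal, hypothetical sketch of such a container, with field names taken from the calls above:

import torch
from dataclasses import dataclass

@dataclass
class FP8Weight:
    # Raw float8_e4m3fn weight exactly as stored in the checkpoint.
    weight: torch.Tensor
    # Scale used to dequantize `weight`.
    weight_scale: torch.Tensor
    # Static activation scale, forwarded to Fp8Linear as `scale_upper_bound`.
    input_scale: torch.Tensor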
@@ -214,8 +246,13 @@ class Weights:
         tensor = f.get_tensor(tensor_name)
         # Special case for gptq which shouldn't convert
         # u4 which are disguised as int32. Exl2 uses int16
-        # as well.
-        if tensor.dtype not in [torch.int16, torch.int32, torch.int64]:
+        # as well. FP8 uses torch.float8_e4m3fn
+        if tensor.dtype not in [
+            torch.float8_e4m3fn,
+            torch.int16,
+            torch.int32,
+            torch.int64,
+        ]:
             tensor = tensor.to(dtype=self.dtype)
         if to_device:
             tensor = tensor.to(device=self.device)
@@ -241,7 +278,8 @@ class Weights:
             raise NotImplementedError("Let's make that generic when needed")
         # Special case for gptq which shouldn't convert
         # u4 which are disguised as int32. exl2 uses int16.
-        if tensor.dtype not in (torch.int16, torch.int32):
+        # FP8 uses torch.float8_e4m3fn.
+        if tensor.dtype not in (torch.float8_e4m3fn, torch.int16, torch.int32):
             tensor = tensor.to(dtype=self.dtype)
         tensor = tensor.to(device=self.device)
         return tensor
@@ -304,7 +342,12 @@ class Weights:
         tensor = tensor.to(device=self.device)

         # Avoid casting quantizer dtypes.
-        if tensor.dtype not in [torch.int16, torch.int32, torch.int64]:
+        if tensor.dtype not in [
+            torch.float8_e4m3fn,
+            torch.int16,
+            torch.int32,
+            torch.int64,
+        ]:
             tensor = tensor.to(dtype=self.dtype)

         return tensor
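
These dtype checks matter because the loader otherwise casts every tensor to the model dtype; an fp8 tensor that went through that cast would reach the layer as plain bf16 values with its checkpoint scale never applied, and the FP8/Marlin code paths above would no longer see float8_e4m3fn at all. A small stand-alone illustration of the difference:

import torch

q = torch.tensor([0.5, 1.0], dtype=torch.bfloat16).to(torch.float8_e4m3fn)

# New behaviour: fp8 tensors are left untouched, so the layers receive the
# raw quantized codes together with their checkpoint scales.
print(q.dtype)                     # torch.float8_e4m3fn

# What the unconditional cast would have done: plain bf16 values, scales ignored.
print(q.to(torch.bfloat16).dtype)  # torch.bfloat16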