allow loading fp8 weights directly

This commit is contained in:
OlivierDehaene 2024-07-18 15:05:04 +02:00
parent 27084bbfd3
commit ee4174b6c7
No known key found for this signature in database
GPG Key ID: BB104D67809DA93C
3 changed files with 88 additions and 21 deletions

View File

@ -1,6 +1,7 @@
import torch
from dataclasses import dataclass
import torch
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.weights import Weight
@ -64,27 +65,41 @@ class Fp8Weight(Weight):
class Fp8Linear(torch.nn.Module):
def __init__(
self,
weight,
qweight,
scale,
scale_upper_bound,
bias,
dtype,
) -> None:
super().__init__()
self.dtype = weight.dtype
self.qweight, self.scale = fp8_quantize(weight)
self.dtype = dtype
self.qweight = qweight
self.scale = scale
self.scale_upper_bound = scale_upper_bound
self.bias = bias if bias is not None else None
def forward(self, input: torch.Tensor) -> torch.Tensor:
if HAS_FBGEMM:
global default_activation_scale_upper_bound
device = input.device
if default_activation_scale_upper_bound.device != device:
default_activation_scale_upper_bound = (
default_activation_scale_upper_bound.to(device)
@classmethod
def from_unquant(cls, weight, bias, dtype):
qweight, scale = fp8_quantize(weight)
return cls(
qweight=qweight, scale=scale, scale_upper_bound=None, bias=bias, dtype=dtype
)
@classmethod
def from_fp8(cls, weight, bias, dtype):
return cls(
qweight=weight.weight,
scale=weight.weight_scale,
scale_upper_bound=weight.input_scale,
bias=bias,
dtype=dtype,
)
def forward(self, input: torch.Tensor) -> torch.Tensor:
if HAS_FBGEMM:
qinput, scale = fp8_quantize(
input, scale_upper_bound=default_activation_scale_upper_bound
input, scale_upper_bound=self.scale_upper_bound
)
y = torch.ops.fbgemm.f8f8bf16_rowwise(

View File

@ -503,7 +503,8 @@ class GPTQMarlinFP8Linear(nn.Module):
def __init__(
self,
weight: torch.Tensor,
qweight: torch.Tensor,
scale: torch.Tensor,
bias: Optional[torch.Tensor],
) -> None:
super().__init__()
@ -513,7 +514,6 @@ class GPTQMarlinFP8Linear(nn.Module):
log_once(logger.info, "GPU does not support FP8, using Marlin FP8 kernel")
qweight, scale = fp8_quantize(weight)
scale = scale.to(torch.float16)
qweight, scales = repack_fp8_for_marlin(qweight, scale)
@ -529,6 +529,15 @@ class GPTQMarlinFP8Linear(nn.Module):
out_features // 64 * 16, dtype=torch.int, device=qweight.device
)
@classmethod
def from_unquant(cls, weight, bias):
qweight, scale = fp8_quantize(weight)
return cls(qweight=qweight, scale=scale, bias=bias)
@classmethod
def from_fp8(cls, weight, bias):
return cls(qweight=weight.weight, scale=weight.weight_scale, bias=bias)
def forward(self, A: torch.Tensor) -> torch.Tensor:
assert marlin_kernels is not None

View File

@ -1,3 +1,5 @@
import torch
from abc import ABC, abstractmethod
from contextlib import contextmanager
from dataclasses import dataclass
@ -5,9 +7,9 @@ from enum import Enum, auto
from pathlib import Path
from typing import Dict, List, Optional, Union
import torch
from safetensors import safe_open
from text_generation_server.utils.import_utils import SYSTEM
from safetensors import safe_open
from dataclasses import dataclass
class WeightsLoader(ABC):
@ -127,15 +129,45 @@ class DefaultWeightsLoader(WeightsLoader):
),
)
w = weights.get_packed_sharded(
f"{prefix}.weight", dim=0, block_sizes=block_sizes
)
# FP8 branch
if w.dtype == torch.float8_e4m3fn:
scale = weights.get_packed_sharded(
f"{prefix}.weight_scale", dim=0, block_sizes=block_sizes
)
input_scale = weights.get_tensor(f"{prefix}.input_scale")
return FP8Weight(weight=w, weight_scale=scale, input_scale=input_scale)
return w
def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
w = [weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
return self.weight_class(torch.cat(w, dim=dim))
w = [weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
w = torch.cat(w, dim=dim)
# FP8 branch
if w.dtype == torch.float8_e4m3fn:
scale = [weights.get_sharded(f"{p}.weight_scale", dim=0) for p in prefixes]
scale = torch.cat(scale, dim=0)
input_scale = weights.get_tensor(f"{prefixes[0]}.input_scale")
return FP8Weight(weight=w, weight_scale=scale, input_scale=input_scale)
return w
def get_weights_row(self, weights: "Weights", prefix: str):
return self.weight_class(
weights.get_sharded(f"{prefix}.weight", dim=1),
)
w = weights.get_sharded(f"{prefix}.weight", dim=1)
# FP8 branch
if w.dtype == torch.float8_e4m3fn:
scale = weights.get_sharded(f"{prefix}.weight_scale", dim=0)
input_scale = weights.get_tensor(f"{prefix}.input_scale")
return FP8Weight(weight=w, weight_scale=scale, input_scale=input_scale)
return w
class Weights:
def __init__(
@ -214,8 +246,13 @@ class Weights:
tensor = f.get_tensor(tensor_name)
# Special case for gptq which shouldn't convert
# u4 which are disguised as int32. Exl2 uses int16
# as well.
if tensor.dtype not in [torch.int16, torch.int32, torch.int64]:
# as well. FP8 uses torch.float8_e4m3fn
if tensor.dtype not in [
torch.float8_e4m3fn,
torch.int16,
torch.int32,
torch.int64,
]:
tensor = tensor.to(dtype=self.dtype)
if to_device:
tensor = tensor.to(device=self.device)
@ -241,7 +278,8 @@ class Weights:
raise NotImplementedError("Let's make that generic when needed")
# Special case for gptq which shouldn't convert
# u4 which are disguised as int32. exl2 uses int16.
if tensor.dtype not in (torch.int16, torch.int32):
# FP8 uses torch.float8_e4m3fn.
if tensor.dtype not in (torch.float8_e4m3fn, torch.int16, torch.int32):
tensor = tensor.to(dtype=self.dtype)
tensor = tensor.to(device=self.device)
return tensor
@ -304,7 +342,12 @@ class Weights:
tensor = tensor.to(device=self.device)
# Avoid casting quantizer dtypes.
if tensor.dtype not in [torch.int16, torch.int32, torch.int64]:
if tensor.dtype not in [
torch.float8_e4m3fn,
torch.int16,
torch.int32,
torch.int64,
]:
tensor = tensor.to(dtype=self.dtype)
return tensor