mirror of https://github.com/huggingface/text-generation-inference.git
allow loading fp8 weights directly

commit ee4174b6c7 (parent 27084bbfd3)
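The change lets the weight loaders consume checkpoints that already ship FP8 tensors instead of quantizing bf16/fp16 weights at load time. A rough sketch of the per-layer layout such a checkpoint is expected to carry (tensor names are taken from the diff below; the dequantization relation is the usual per-tensor FP8 scaling, not something defined by this commit):

    import torch

    # Hypothetical tensors for one linear layer of an FP8 checkpoint.
    qweight = torch.zeros(4096, 4096, dtype=torch.float8_e4m3fn)  # "<prefix>.weight"
    weight_scale = torch.tensor(0.01)                             # "<prefix>.weight_scale"
    input_scale = torch.tensor(1.0)                               # "<prefix>.input_scale"

    # Per-tensor scaling: the high-precision weight is approximately the
    # FP8 payload times its scale.
    w_approx = qweight.to(torch.bfloat16) * weight_scale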
@@ -1,6 +1,7 @@
+import torch
+
 from dataclasses import dataclass
 
-import torch
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.utils.weights import Weight
 
@@ -64,27 +65,41 @@ class Fp8Weight(Weight):
 class Fp8Linear(torch.nn.Module):
     def __init__(
         self,
-        weight,
+        qweight,
+        scale,
+        scale_upper_bound,
         bias,
         dtype,
     ) -> None:
         super().__init__()
-        self.dtype = weight.dtype
-        self.qweight, self.scale = fp8_quantize(weight)
+        self.dtype = dtype
+        self.qweight = qweight
+        self.scale = scale
+        self.scale_upper_bound = scale_upper_bound
 
         self.bias = bias if bias is not None else None
 
-    def forward(self, input: torch.Tensor) -> torch.Tensor:
-        if HAS_FBGEMM:
-            global default_activation_scale_upper_bound
-
-            device = input.device
-            if default_activation_scale_upper_bound.device != device:
-                default_activation_scale_upper_bound = (
-                    default_activation_scale_upper_bound.to(device)
-                )
+    @classmethod
+    def from_unquant(cls, weight, bias, dtype):
+        qweight, scale = fp8_quantize(weight)
+        return cls(
+            qweight=qweight, scale=scale, scale_upper_bound=None, bias=bias, dtype=dtype
+        )
+
+    @classmethod
+    def from_fp8(cls, weight, bias, dtype):
+        return cls(
+            qweight=weight.weight,
+            scale=weight.weight_scale,
+            scale_upper_bound=weight.input_scale,
+            bias=bias,
+            dtype=dtype,
+        )
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        if HAS_FBGEMM:
             qinput, scale = fp8_quantize(
-                input, scale_upper_bound=default_activation_scale_upper_bound
+                input, scale_upper_bound=self.scale_upper_bound
            )
 
             y = torch.ops.fbgemm.f8f8bf16_rowwise(
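The two classmethods added above split the construction paths: from_unquant keeps the old quantize-at-load behaviour, while from_fp8 takes weights that are already FP8 together with their scales. A hedged usage sketch (the dispatch helper is an assumption, not code from this commit; Fp8Linear is the class in the hunk above, and the container is expected to expose .weight, .weight_scale and .input_scale):

    import torch

    def build_fp8_linear(weight, bias, dtype):
        # Pre-quantized checkpoint: the loader handed us an FP8 container.
        if hasattr(weight, "weight") and weight.weight.dtype == torch.float8_e4m3fn:
            return Fp8Linear.from_fp8(weight, bias, dtype)
        # Plain fp16/bf16 tensor: quantize on the fly, as before this commit.
        return Fp8Linear.from_unquant(weight, bias, dtype)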
@@ -503,7 +503,8 @@ class GPTQMarlinFP8Linear(nn.Module):
 
     def __init__(
         self,
-        weight: torch.Tensor,
+        qweight: torch.Tensor,
+        scale: torch.Tensor,
         bias: Optional[torch.Tensor],
     ) -> None:
         super().__init__()
@@ -513,7 +514,6 @@ class GPTQMarlinFP8Linear(nn.Module):
 
         log_once(logger.info, "GPU does not support FP8, using Marlin FP8 kernel")
 
-        qweight, scale = fp8_quantize(weight)
         scale = scale.to(torch.float16)
         qweight, scales = repack_fp8_for_marlin(qweight, scale)
 
@@ -529,6 +529,15 @@ class GPTQMarlinFP8Linear(nn.Module):
             out_features // 64 * 16, dtype=torch.int, device=qweight.device
         )
 
+    @classmethod
+    def from_unquant(cls, weight, bias):
+        qweight, scale = fp8_quantize(weight)
+        return cls(qweight=qweight, scale=scale, bias=bias)
+
+    @classmethod
+    def from_fp8(cls, weight, bias):
+        return cls(qweight=weight.weight, scale=weight.weight_scale, bias=bias)
+
     def forward(self, A: torch.Tensor) -> torch.Tensor:
         assert marlin_kernels is not None
 
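Both linear layers still fall back to fp8_quantize for checkpoints that are not already FP8. For reference, a minimal per-tensor FP8 quantizer looks like the sketch below (the general technique only, not necessarily TGI's fp8_quantize):

    import torch

    def per_tensor_fp8_quantize(weight: torch.Tensor):
        finfo = torch.finfo(torch.float8_e4m3fn)
        # Scale so the largest magnitude maps onto the E4M3 representable range.
        scale = finfo.max / weight.abs().max().clamp(min=1e-12)
        qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max)
        qweight = qweight.to(torch.float8_e4m3fn)
        # Return the reciprocal as the dequantization scale: payload * scale ~= weight.
        return qweight, scale.reciprocal()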
@@ -1,3 +1,5 @@
+import torch
+
 from abc import ABC, abstractmethod
 from contextlib import contextmanager
 from dataclasses import dataclass
@@ -5,9 +7,9 @@ from enum import Enum, auto
 from pathlib import Path
 from typing import Dict, List, Optional, Union
 
-import torch
-from safetensors import safe_open
 from text_generation_server.utils.import_utils import SYSTEM
+from safetensors import safe_open
+from dataclasses import dataclass
 
 
 class WeightsLoader(ABC):
@@ -127,15 +129,45 @@ class DefaultWeightsLoader(WeightsLoader):
-            ),
-        )
+        w = weights.get_packed_sharded(
+            f"{prefix}.weight", dim=0, block_sizes=block_sizes
+        )
+        # FP8 branch
+        if w.dtype == torch.float8_e4m3fn:
+            scale = weights.get_packed_sharded(
+                f"{prefix}.weight_scale", dim=0, block_sizes=block_sizes
+            )
+            input_scale = weights.get_tensor(f"{prefix}.input_scale")
+            return FP8Weight(weight=w, weight_scale=scale, input_scale=input_scale)
+        return w
 
     def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
-        w = [weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
-        return self.weight_class(torch.cat(w, dim=dim))
+        w = [weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
+        w = torch.cat(w, dim=dim)
+        # FP8 branch
+        if w.dtype == torch.float8_e4m3fn:
+            scale = [weights.get_sharded(f"{p}.weight_scale", dim=0) for p in prefixes]
+            scale = torch.cat(scale, dim=0)
+            input_scale = weights.get_tensor(f"{prefixes[0]}.input_scale")
+            return FP8Weight(weight=w, weight_scale=scale, input_scale=input_scale)
+        return w
 
     def get_weights_row(self, weights: "Weights", prefix: str):
-        return self.weight_class(
-            weights.get_sharded(f"{prefix}.weight", dim=1),
-        )
+        w = weights.get_sharded(f"{prefix}.weight", dim=1)
+        # FP8 branch
+        if w.dtype == torch.float8_e4m3fn:
+            scale = weights.get_sharded(f"{prefix}.weight_scale", dim=0)
+            input_scale = weights.get_tensor(f"{prefix}.input_scale")
+            return FP8Weight(weight=w, weight_scale=scale, input_scale=input_scale)
+        return w
 
 
 class Weights:
     def __init__(
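The FP8Weight container returned above is defined outside the shown hunks; judging from the constructor calls and the dataclass import added at the top of the file, it is presumably along these lines (an assumption, not the commit's code):

    import torch
    from dataclasses import dataclass

    @dataclass
    class FP8Weight:
        weight: torch.Tensor        # float8_e4m3fn payload
        weight_scale: torch.Tensor  # per-tensor dequantization scale
        input_scale: torch.Tensor   # activation scale bound passed to fp8_quantize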
@@ -214,8 +246,13 @@ class Weights:
         tensor = f.get_tensor(tensor_name)
         # Special case for gptq which shouldn't convert
         # u4 which are disguised as int32. Exl2 uses int16
-        # as well.
-        if tensor.dtype not in [torch.int16, torch.int32, torch.int64]:
+        # as well. FP8 uses torch.float8_e4m3fn
+        if tensor.dtype not in [
+            torch.float8_e4m3fn,
+            torch.int16,
+            torch.int32,
+            torch.int64,
+        ]:
             tensor = tensor.to(dtype=self.dtype)
         if to_device:
             tensor = tensor.to(device=self.device)
@@ -241,7 +278,8 @@ class Weights:
             raise NotImplementedError("Let's make that generic when needed")
         # Special case for gptq which shouldn't convert
         # u4 which are disguised as int32. exl2 uses int16.
-        if tensor.dtype not in (torch.int16, torch.int32):
+        # FP8 uses torch.float8_e4m3fn.
+        if tensor.dtype not in (torch.float8_e4m3fn, torch.int16, torch.int32):
             tensor = tensor.to(dtype=self.dtype)
         tensor = tensor.to(device=self.device)
         return tensor
@@ -304,7 +342,12 @@ class Weights:
         tensor = tensor.to(device=self.device)
 
         # Avoid casting quantizer dtypes.
-        if tensor.dtype not in [torch.int16, torch.int32, torch.int64]:
+        if tensor.dtype not in [
+            torch.float8_e4m3fn,
+            torch.int16,
+            torch.int32,
+            torch.int64,
+        ]:
             tensor = tensor.to(dtype=self.dtype)
 
         return tensor
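The three dtype guards above keep FP8 payloads out of the default-dtype cast for the same reason packed GPTQ int32 weights are skipped: a blind cast loses the float8 dtype the loader branches on and yields values that are off by the stored weight_scale. A tiny illustration (assumes a PyTorch build with float8 support):

    import torch

    q = torch.tensor([[0.5, 0.5], [0.5, 0.5]]).to(torch.float8_e4m3fn)
    weight_scale = torch.tensor(8.0)

    print(q.to(torch.float16) * weight_scale)  # explicit dequantization: 4.0 everywhere
    print(q.to(torch.float16))                 # a blind cast: 0.5, off by the scale factor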