Mirror of https://github.com/huggingface/text-generation-inference.git

commit ee4174b6c7 (parent 27084bbfd3)

    allow loading fp8 weights directly
@@ -1,6 +1,7 @@
+import torch
+
 from dataclasses import dataclass
 
-import torch
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.utils.weights import Weight
 
@@ -64,27 +65,41 @@ class Fp8Weight(Weight):
 class Fp8Linear(torch.nn.Module):
     def __init__(
         self,
-        weight,
+        qweight,
+        scale,
+        scale_upper_bound,
         bias,
+        dtype,
     ) -> None:
         super().__init__()
-        self.dtype = weight.dtype
-        self.qweight, self.scale = fp8_quantize(weight)
+        self.dtype = dtype
+        self.qweight = qweight
+        self.scale = scale
+        self.scale_upper_bound = scale_upper_bound
 
         self.bias = bias if bias is not None else None
 
+    @classmethod
+    def from_unquant(cls, weight, bias, dtype):
+        qweight, scale = fp8_quantize(weight)
+        return cls(
+            qweight=qweight, scale=scale, scale_upper_bound=None, bias=bias, dtype=dtype
+        )
+
+    @classmethod
+    def from_fp8(cls, weight, bias, dtype):
+        return cls(
+            qweight=weight.weight,
+            scale=weight.weight_scale,
+            scale_upper_bound=weight.input_scale,
+            bias=bias,
+            dtype=dtype,
+        )
+
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         if HAS_FBGEMM:
-            global default_activation_scale_upper_bound
-
-            device = input.device
-            if default_activation_scale_upper_bound.device != device:
-                default_activation_scale_upper_bound = (
-                    default_activation_scale_upper_bound.to(device)
-                )
-
             qinput, scale = fp8_quantize(
-                input, scale_upper_bound=default_activation_scale_upper_bound
+                input, scale_upper_bound=self.scale_upper_bound
             )
 
             y = torch.ops.fbgemm.f8f8bf16_rowwise(
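The two new constructors split load-time behavior: `from_unquant` keeps the old path (quantize a bf16/fp16 weight on the fly), while `from_fp8` reuses the quantized tensor and scales stored in an FP8 checkpoint. A minimal dispatch sketch, assuming the module path and that the weights loader hands back either a plain tensor or an FP8Weight-style object as in the hunks further down; the helper name is hypothetical:

import torch

from text_generation_server.layers.fp8 import Fp8Linear  # assumed module path


def build_fp8_linear(w, bias, dtype):
    # Hypothetical helper, not part of the commit.
    if isinstance(w, torch.Tensor):
        # Unquantized checkpoint: quantize the weight at load time (old behavior).
        return Fp8Linear.from_unquant(w, bias, dtype)
    # FP8 checkpoint: reuse the stored qweight, weight_scale and input_scale.
    return Fp8Linear.from_fp8(w, bias, dtype)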
@@ -503,7 +503,8 @@ class GPTQMarlinFP8Linear(nn.Module):
 
     def __init__(
         self,
-        weight: torch.Tensor,
+        qweight: torch.Tensor,
+        scale: torch.Tensor,
         bias: Optional[torch.Tensor],
     ) -> None:
         super().__init__()
@@ -513,7 +514,6 @@ class GPTQMarlinFP8Linear(nn.Module):
 
         log_once(logger.info, "GPU does not support FP8, using Marlin FP8 kernel")
 
-        qweight, scale = fp8_quantize(weight)
         scale = scale.to(torch.float16)
         qweight, scales = repack_fp8_for_marlin(qweight, scale)
 
@@ -529,6 +529,15 @@ class GPTQMarlinFP8Linear(nn.Module):
             out_features // 64 * 16, dtype=torch.int, device=qweight.device
         )
 
+    @classmethod
+    def from_unquant(cls, weight, bias):
+        qweight, scale = fp8_quantize(weight)
+        return cls(qweight=qweight, scale=scale, bias=bias)
+
+    @classmethod
+    def from_fp8(cls, weight, bias):
+        return cls(qweight=weight.weight, scale=weight.weight_scale, bias=bias)
+
     def forward(self, A: torch.Tensor) -> torch.Tensor:
         assert marlin_kernels is not None
 
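Both `from_fp8` constructors read the same attribute names off the loaded weight object: `weight` and `weight_scale`, plus `input_scale` in the `Fp8Linear` case. A small sketch of the expected shape, using a stand-in object since the concrete FP8Weight class is not shown in these hunks:

import torch
from types import SimpleNamespace

# Stand-in for the FP8Weight object built by the weights-loader hunks below;
# the field names are taken from how the from_fp8 constructors consume them.
fp8_weight = SimpleNamespace(
    weight=torch.zeros(16, 16).to(torch.float8_e4m3fn),  # pre-quantized FP8 weight
    weight_scale=torch.tensor(0.05),                      # weight dequantization scale
    input_scale=torch.tensor(1.0),                        # activation scale upper bound
)
# Fp8Linear.from_fp8(fp8_weight, bias=None, dtype=torch.bfloat16) reuses these
# tensors; GPTQMarlinFP8Linear.from_fp8 only needs weight and weight_scale.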
@@ -1,3 +1,5 @@
+import torch
+
 from abc import ABC, abstractmethod
 from contextlib import contextmanager
 from dataclasses import dataclass
@@ -5,9 +7,9 @@ from enum import Enum, auto
 from pathlib import Path
 from typing import Dict, List, Optional, Union
 
-import torch
-from safetensors import safe_open
 from text_generation_server.utils.import_utils import SYSTEM
+from safetensors import safe_open
+from dataclasses import dataclass
 
 
 class WeightsLoader(ABC):
@@ -127,15 +129,45 @@ class DefaultWeightsLoader(WeightsLoader):
             ),
         )
+
+        w = weights.get_packed_sharded(
+            f"{prefix}.weight", dim=0, block_sizes=block_sizes
+        )
+        # FP8 branch
+        if w.dtype == torch.float8_e4m3fn:
+            scale = weights.get_packed_sharded(
+                f"{prefix}.weight_scale", dim=0, block_sizes=block_sizes
+            )
+            input_scale = weights.get_tensor(f"{prefix}.input_scale")
+            return FP8Weight(weight=w, weight_scale=scale, input_scale=input_scale)
+        return w
 
     def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
         w = [weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
         return self.weight_class(torch.cat(w, dim=dim))
+
+        w = [weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
+        w = torch.cat(w, dim=dim)
+        # FP8 branch
+        if w.dtype == torch.float8_e4m3fn:
+            scale = [weights.get_sharded(f"{p}.weight_scale", dim=0) for p in prefixes]
+            scale = torch.cat(scale, dim=0)
+            input_scale = weights.get_tensor(f"{prefixes[0]}.input_scale")
+            return FP8Weight(weight=w, weight_scale=scale, input_scale=input_scale)
+        return w
 
     def get_weights_row(self, weights: "Weights", prefix: str):
         return self.weight_class(
             weights.get_sharded(f"{prefix}.weight", dim=1),
         )
+
+        w = weights.get_sharded(f"{prefix}.weight", dim=1)
+        # FP8 branch
+        if w.dtype == torch.float8_e4m3fn:
+            scale = weights.get_sharded(f"{prefix}.weight_scale", dim=0)
+            input_scale = weights.get_tensor(f"{prefix}.input_scale")
+            return FP8Weight(weight=w, weight_scale=scale, input_scale=input_scale)
+        return w
 
 
 class Weights:
     def __init__(
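The loader branches above return an `FP8Weight` container whose definition is not part of the shown hunks. Given the added `from dataclasses import dataclass` import and the keyword arguments used, it plausibly looks something like this sketch (an assumption, not the commit's actual definition):

from dataclasses import dataclass

import torch


@dataclass
class FP8Weight:
    weight: torch.Tensor        # float8_e4m3fn tensor, sharded like the regular weight
    weight_scale: torch.Tensor  # weight dequantization scale from the checkpoint
    input_scale: torch.Tensor   # activation scale upper bound from the checkpoint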
@@ -214,8 +246,13 @@ class Weights:
         tensor = f.get_tensor(tensor_name)
         # Special case for gptq which shouldn't convert
         # u4 which are disguised as int32. Exl2 uses int16
-        # as well.
-        if tensor.dtype not in [torch.int16, torch.int32, torch.int64]:
+        # as well. FP8 uses torch.float8_e4m3fn
+        if tensor.dtype not in [
+            torch.float8_e4m3fn,
+            torch.int16,
+            torch.int32,
+            torch.int64,
+        ]:
             tensor = tensor.to(dtype=self.dtype)
         if to_device:
             tensor = tensor.to(device=self.device)
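The widened skip-list matters because casting a float8_e4m3fn tensor to the serving dtype would silently dequantize it. A minimal illustration of the guard, assuming a PyTorch build with float8 support:

import torch

serve_dtype = torch.bfloat16
skip = (torch.float8_e4m3fn, torch.int16, torch.int32, torch.int64)

fp8_tensor = torch.zeros(4, 4).to(torch.float8_e4m3fn)
fp16_tensor = torch.zeros(4, 4, dtype=torch.float16)

for t in (fp8_tensor, fp16_tensor):
    if t.dtype not in skip:
        t = t.to(dtype=serve_dtype)
    print(t.dtype)  # fp8 weights keep float8_e4m3fn, fp16 weights become bfloat16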
@@ -241,7 +278,8 @@ class Weights:
             raise NotImplementedError("Let's make that generic when needed")
         # Special case for gptq which shouldn't convert
         # u4 which are disguised as int32. exl2 uses int16.
-        if tensor.dtype not in (torch.int16, torch.int32):
+        # FP8 uses torch.float8_e4m3fn.
+        if tensor.dtype not in (torch.float8_e4m3fn, torch.int16, torch.int32):
             tensor = tensor.to(dtype=self.dtype)
         tensor = tensor.to(device=self.device)
         return tensor
@@ -304,7 +342,12 @@ class Weights:
         tensor = tensor.to(device=self.device)
 
         # Avoid casting quantizer dtypes.
-        if tensor.dtype not in [torch.int16, torch.int32, torch.int64]:
+        if tensor.dtype not in [
+            torch.float8_e4m3fn,
+            torch.int16,
+            torch.int32,
+            torch.int64,
+        ]:
             tensor = tensor.to(dtype=self.dtype)
 
         return tensor