fix auto conversion

OlivierDehaene 2024-07-20 09:16:42 +02:00
parent 6a93a24f3f
commit 5789139c68
3 changed files with 15 additions and 12 deletions

server/text_generation_server/layers/fp8.py

@@ -67,8 +67,9 @@ def fp8_quantize(weight, scale_upper_bound=None, qdtype=torch.float8_e4m3fn):
 class HybridFP8UnquantLoader(WeightsLoader):
     """Weight loader that loads FP8 and unquantized Torch tensors."""
 
-    def __init__(self, activation_scale_ub: Optional[float]):
+    def __init__(self, activation_scale_ub: Optional[float], to_fp8: bool):
         self.activation_scale_ub = activation_scale_ub
+        self.to_fp8 = to_fp8
 
     def get_weights_col_packed(
         self,
@@ -91,6 +92,8 @@ class HybridFP8UnquantLoader(WeightsLoader):
                 activation_scale_ub=self.activation_scale_ub,
                 dtype=weights.dtype,
             )
+        if self.to_fp8:
+            return Fp8Weight(weight=w, dtype=weights.dtype)
         return UnquantizedWeight(w)
@@ -111,6 +114,8 @@ class HybridFP8UnquantLoader(WeightsLoader):
                 activation_scale_ub=self.activation_scale_ub,
                 dtype=weights.dtype,
             )
+        if self.to_fp8:
+            return Fp8Weight(weight=w, dtype=weights.dtype)
         return UnquantizedWeight(w)
@@ -125,6 +130,8 @@ class HybridFP8UnquantLoader(WeightsLoader):
                 activation_scale_ub=self.activation_scale_ub,
                 dtype=weights.dtype,
             )
+        if self.to_fp8:
+            return Fp8Weight(weight=w, dtype=weights.dtype)
         return UnquantizedWeight(w)
@@ -186,8 +193,6 @@ class Fp8Linear(torch.nn.Module):
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         if FBGEMM_MM_AVAILABLE:
-            log_once(logger.info, "Using FBGEMM fp8 kernels")
-
             qinput, scale = fp8_quantize(
                 input, scale_upper_bound=self.scale_upper_bound
             )
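Taken together, the fp8.py hunks thread the new to_fp8 flag through the
loader: a tensor that is already float8_e4m3fn keeps the existing FP8 path,
while plain tensors are wrapped in Fp8Weight (and quantized downstream) only
when the flag is set. A minimal sketch of the two modes, assuming the
text_generation_server package is importable; only the constructor signature
shown in this diff is used:

from text_generation_server.layers.fp8 import HybridFP8UnquantLoader

# to_fp8=True: weights not already stored as float8 come back as Fp8Weight
# and are quantized on the fly.
auto_convert = HybridFP8UnquantLoader(activation_scale_ub=None, to_fp8=True)

# to_fp8=False: only weights already stored as float8 stay FP8; everything
# else loads as UnquantizedWeight.
passthrough = HybridFP8UnquantLoader(activation_scale_ub=None, to_fp8=False)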

server/text_generation_server/models/custom_modeling/flash_llama_modeling.py

@@ -33,7 +33,6 @@ from text_generation_server.layers.attention import (
     attention,
     reshape_and_cache,
 )
-from text_generation_server.models.globals import FLASH_DECODING
 from text_generation_server.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
@@ -42,16 +41,15 @@ from text_generation_server.layers import (
     TensorParallelMultiAdapterLinear,
     TensorParallelAdapterRowLinear,
 )
-from text_generation_server.layers.fp8 import Fp8Weight
 from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.layers.layernorm import (
     FastRMSNorm,
 )
 from text_generation_server.utils.weights import (
-    DefaultWeightsLoader,
-    UnquantizedWeight,
     Weights,
 )
+from text_generation_server.layers.fp8 import HybridFP8UnquantLoader
 
 if SYSTEM == "rocm":
     try:
@@ -113,12 +111,12 @@ def load_attention(config, prefix: str, weights, layer_id):
 @contextmanager
 def no_fp8(weights: Weights):
     """De-activate fp8 auto conversion for the duration of this context manager"""
     weights_loader = weights.weights_loader
-    if (
-        isinstance(weights_loader, DefaultWeightsLoader)
-        and weights_loader.weight_class is Fp8Weight
-    ):
-        weights_loader = DefaultWeightsLoader(UnquantizedWeight)
+    if isinstance(weights_loader, HybridFP8UnquantLoader) and weights_loader.to_fp8:
+        weights_loader = HybridFP8UnquantLoader(
+            weights_loader.activation_scale_ub, to_fp8=False
+        )
 
     with weights.use_loader(weights_loader):
         yield
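The rewritten no_fp8 helper now recognizes the hybrid loader directly and
flips its to_fp8 flag off, instead of matching DefaultWeightsLoader against
the Fp8Weight class. A hedged usage sketch, assuming an already-constructed
Weights instance named weights; the tensor name is purely illustrative:

# Keep specific weights in their checkpoint dtype even when the model is
# being auto-converted to FP8 (sketch; tensor name is hypothetical).
with no_fp8(weights):
    # Loads inside this block see a HybridFP8UnquantLoader with
    # to_fp8=False, so nothing is converted.
    w = weights.get_tensor("model.embed_tokens.weight")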

server/text_generation_server/utils/quantization.py

@@ -172,6 +172,6 @@ def get_loader(
     if isinstance(quantizer_config, _FP8QuantizerConfig):
         activation_scale_ub = quantizer_config.activation_scale_ub
 
-        return HybridFP8UnquantLoader(activation_scale_ub)
+        return HybridFP8UnquantLoader(activation_scale_ub, to_fp8=quantize == "fp8")
     else:
         raise ValueError(f"Unknown quantization method: {quantize}")
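This call site is what the commit title refers to: HybridFP8UnquantLoader
gained a required to_fp8 argument, and get_loader now derives it from the
CLI choice, so auto conversion only happens when FP8 was explicitly
requested. A small sketch of the distinction, assuming quantize holds the
--quantize flag value and activation_scale_ub was read from the checkpoint's
quantization config:

from text_generation_server.layers.fp8 import HybridFP8UnquantLoader

quantize = "fp8"           # --quantize fp8 on the CLI; None when unset
activation_scale_ub = None
loader = HybridFP8UnquantLoader(activation_scale_ub, to_fp8=quantize == "fp8")
assert loader.to_fp8       # unquantized weights will be auto-converted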