Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-12 04:44:52 +00:00)
fix auto conversion
This commit is contained in:
parent 6a93a24f3f
commit 5789139c68
@@ -67,8 +67,9 @@ def fp8_quantize(weight, scale_upper_bound=None, qdtype=torch.float8_e4m3fn):
 class HybridFP8UnquantLoader(WeightsLoader):
     """Weight loader that loads FP8 and unquantized Torch tensors."""

-    def __init__(self, activation_scale_ub: Optional[float]):
+    def __init__(self, activation_scale_ub: Optional[float], to_fp8: bool):
         self.activation_scale_ub = activation_scale_ub
+        self.to_fp8 = to_fp8

     def get_weights_col_packed(
         self,
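The new to_fp8 flag records whether plain (unquantized) checkpoints should be auto-converted to FP8 at load time. A minimal construction sketch, assuming HybridFP8UnquantLoader is imported from text_generation_server.layers.fp8 as elsewhere in this diff; the argument values are illustrative only:

    from text_generation_server.layers.fp8 import HybridFP8UnquantLoader

    # Illustrative values only: auto-convert plain checkpoints to FP8 ...
    fp8_loader = HybridFP8UnquantLoader(activation_scale_ub=None, to_fp8=True)
    # ... or keep them in their original dtype.
    plain_loader = HybridFP8UnquantLoader(activation_scale_ub=None, to_fp8=False)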
@@ -91,6 +92,8 @@ class HybridFP8UnquantLoader(WeightsLoader):
                 activation_scale_ub=self.activation_scale_ub,
                 dtype=weights.dtype,
             )
+        if self.to_fp8:
+            return Fp8Weight(weight=w, dtype=weights.dtype)

         return UnquantizedWeight(w)

@@ -111,6 +114,8 @@ class HybridFP8UnquantLoader(WeightsLoader):
                 activation_scale_ub=self.activation_scale_ub,
                 dtype=weights.dtype,
             )
+        if self.to_fp8:
+            return Fp8Weight(weight=w, dtype=weights.dtype)

         return UnquantizedWeight(w)

@@ -125,6 +130,8 @@ class HybridFP8UnquantLoader(WeightsLoader):
                 activation_scale_ub=self.activation_scale_ub,
                 dtype=weights.dtype,
             )
+        if self.to_fp8:
+            return Fp8Weight(weight=w, dtype=weights.dtype)

         return UnquantizedWeight(w)

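The same two added lines appear in each of the loader's get_weights_* methods. Schematically, the return value is now chosen as follows (a simplified sketch, not the actual method bodies; the first branch, for checkpoints that already ship FP8 tensors, lies outside the shown hunks and is assumed here):

    import torch

    from text_generation_server.layers.fp8 import Fp8Weight
    from text_generation_server.utils.weights import UnquantizedWeight

    def select_weight(loader, w: torch.Tensor, weights_dtype: torch.dtype):
        # Checkpoint already stores FP8 tensors: keep them as Fp8Weight
        # (weight scales omitted here for brevity).
        if w.dtype == torch.float8_e4m3fn:
            return Fp8Weight(weight=w, dtype=weights_dtype)
        # Plain checkpoint, but auto-conversion to FP8 was requested.
        if loader.to_fp8:
            return Fp8Weight(weight=w, dtype=weights_dtype)
        # Otherwise load the tensor unquantized.
        return UnquantizedWeight(w)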
@@ -186,8 +193,6 @@ class Fp8Linear(torch.nn.Module):

     def forward(self, input: torch.Tensor) -> torch.Tensor:
         if FBGEMM_MM_AVAILABLE:
-            log_once(logger.info, "Using FBGEMM fp8 kernels")
-
             qinput, scale = fp8_quantize(
                 input, scale_upper_bound=self.scale_upper_bound
             )
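For reference, fp8_quantize (whose signature appears in the first hunk header above) performs dynamic per-tensor quantization of the input before the FP8 matmul. Below is a rough, self-contained sketch of that idea, not TGI's actual kernel path (which may go through FBGEMM ops); treating scale_upper_bound as a clamp on the observed maximum is an assumption:

    import torch

    def fp8_quantize_sketch(weight, scale_upper_bound=None, qdtype=torch.float8_e4m3fn):
        # Dynamic per-tensor quantization: map the largest magnitude onto the
        # FP8 dtype's maximum representable value.
        finfo = torch.finfo(qdtype)
        amax = weight.abs().max().clamp(min=1e-12)
        if scale_upper_bound is not None:
            # Assumed semantics: clip outliers so the scale stays bounded.
            amax = torch.clamp(amax, max=scale_upper_bound)
        scale = finfo.max / amax
        qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max).to(qdtype)
        # Return the FP8 tensor together with its dequantization scale.
        return qweight, scale.reciprocal()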
@@ -33,7 +33,6 @@ from text_generation_server.layers.attention import (
     attention,
     reshape_and_cache,
 )
-from text_generation_server.models.globals import FLASH_DECODING
 from text_generation_server.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
@@ -42,16 +41,15 @@ from text_generation_server.layers import (
     TensorParallelMultiAdapterLinear,
     TensorParallelAdapterRowLinear,
 )
-from text_generation_server.layers.fp8 import Fp8Weight
 from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.layers.layernorm import (
     FastRMSNorm,
 )
 from text_generation_server.utils.weights import (
     DefaultWeightsLoader,
     UnquantizedWeight,
     Weights,
 )
+from text_generation_server.layers.fp8 import HybridFP8UnquantLoader

 if SYSTEM == "rocm":
     try:
@@ -113,12 +111,12 @@ def load_attention(config, prefix: str, weights, layer_id):

 @contextmanager
 def no_fp8(weights: Weights):
     """De-activate fp8 auto conversion for the duration of this context manager"""
     weights_loader = weights.weights_loader
-    if (
-        isinstance(weights_loader, DefaultWeightsLoader)
-        and weights_loader.weight_class is Fp8Weight
-    ):
-        weights_loader = DefaultWeightsLoader(UnquantizedWeight)
+    if isinstance(weights_loader, HybridFP8UnquantLoader) and weights_loader.to_fp8:
+        weights_loader = HybridFP8UnquantLoader(
+            weights_loader.activation_scale_ub, to_fp8=False
+        )

     with weights.use_loader(weights_loader):
         yield
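With this rewrite, auto-conversion is disabled by constructing a second hybrid loader with to_fp8=False instead of swapping in a DefaultWeightsLoader. A hypothetical call site (the prefix and accessor are illustrative, not part of this commit):

    # Hypothetical usage: load one weight without FP8 auto-conversion.
    with no_fp8(weights):
        w = weights.get_weights_col("model.some_projection")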
@@ -172,6 +172,6 @@ def get_loader(
         if isinstance(quantizer_config, _FP8QuantizerConfig):
             activation_scale_ub = quantizer_config.activation_scale_ub

-        return HybridFP8UnquantLoader(activation_scale_ub)
+        return HybridFP8UnquantLoader(activation_scale_ub, to_fp8=quantize == "fp8")
     else:
         raise ValueError(f"Unknown quantization method: {quantize}")
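The hybrid loader is still returned in both cases, but plain checkpoints are only auto-converted when FP8 was explicitly requested. A simplified sketch of the decision (the helper name and the reading of quantize as the server's --quantize option are assumptions):

    from typing import Optional

    from text_generation_server.layers.fp8 import HybridFP8UnquantLoader

    def make_hybrid_loader(quantize: Optional[str], activation_scale_ub: Optional[float]):
        # Auto-convert unquantized checkpoints only when "fp8" was requested.
        return HybridFP8UnquantLoader(activation_scale_ub, to_fp8=quantize == "fp8")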