Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-12 04:44:52 +00:00)
fix auto conversion
parent 6a93a24f3f
commit 5789139c68
@@ -67,8 +67,9 @@ def fp8_quantize(weight, scale_upper_bound=None, qdtype=torch.float8_e4m3fn):
 class HybridFP8UnquantLoader(WeightsLoader):
     """Weight loader that loads FP8 and unquantized Torch tensors."""
 
-    def __init__(self, activation_scale_ub: Optional[float]):
+    def __init__(self, activation_scale_ub: Optional[float], to_fp8: bool):
         self.activation_scale_ub = activation_scale_ub
+        self.to_fp8 = to_fp8
 
     def get_weights_col_packed(
         self,
@@ -91,6 +92,8 @@ class HybridFP8UnquantLoader(WeightsLoader):
                 activation_scale_ub=self.activation_scale_ub,
                 dtype=weights.dtype,
             )
+        if self.to_fp8:
+            return Fp8Weight(weight=w, dtype=weights.dtype)
 
         return UnquantizedWeight(w)
 
@@ -111,6 +114,8 @@ class HybridFP8UnquantLoader(WeightsLoader):
                 activation_scale_ub=self.activation_scale_ub,
                 dtype=weights.dtype,
             )
+        if self.to_fp8:
+            return Fp8Weight(weight=w, dtype=weights.dtype)
 
         return UnquantizedWeight(w)
 
@@ -125,6 +130,8 @@ class HybridFP8UnquantLoader(WeightsLoader):
                 activation_scale_ub=self.activation_scale_ub,
                 dtype=weights.dtype,
             )
+        if self.to_fp8:
+            return Fp8Weight(weight=w, dtype=weights.dtype)
 
         return UnquantizedWeight(w)
 
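The same two-line branch is added after the FP8 path in each of the loader's getters: when the tensor read from the checkpoint is not already FP8, the new `to_fp8` flag decides whether it is wrapped as an `Fp8Weight` (and therefore quantized at load time) or returned as an `UnquantizedWeight`. A minimal sketch of that dispatch, with prefix handling, sharding and FP8 scale loading elided; the helper name and the dtype check are illustrative, not part of the diff:

from typing import Optional

import torch

from text_generation_server.layers.fp8 import Fp8Weight
from text_generation_server.utils.weights import UnquantizedWeight


def wrap_loaded_weight(
    w: torch.Tensor,
    to_fp8: bool,
    activation_scale_ub: Optional[float],
    weights_dtype: torch.dtype,
):
    # Hypothetical helper mirroring the branch added to each getter.
    if w.dtype == torch.float8_e4m3fn:
        # The checkpoint already stores this tensor in FP8 (scale loading elided).
        return Fp8Weight(
            weight=w,
            activation_scale_ub=activation_scale_ub,
            dtype=weights_dtype,
        )
    if to_fp8:
        # Auto conversion requested: quantize the unquantized tensor at load time.
        return Fp8Weight(weight=w, dtype=weights_dtype)
    return UnquantizedWeight(w)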
@@ -186,8 +193,6 @@ class Fp8Linear(torch.nn.Module):
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         if FBGEMM_MM_AVAILABLE:
-            log_once(logger.info, "Using FBGEMM fp8 kernels")
-
             qinput, scale = fp8_quantize(
                 input, scale_upper_bound=self.scale_upper_bound
             )
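For context, `Fp8Linear.forward` dynamically quantizes its input before the FBGEMM matmul. Below is a plain-PyTorch sketch of a per-tensor quantizer matching the `fp8_quantize(weight, scale_upper_bound=None, qdtype=torch.float8_e4m3fn)` signature above. It is an assumption-laden reference, not the kernel the FBGEMM path actually calls; the treatment of `scale_upper_bound` and the returned scale convention are assumed.

from typing import Optional

import torch


def fp8_quantize_reference(
    weight: torch.Tensor,
    scale_upper_bound: Optional[float] = None,
    qdtype: torch.dtype = torch.float8_e4m3fn,
):
    # Dynamic per-tensor quantization: scale so the largest magnitude maps to
    # the format's maximum representable value.
    finfo = torch.finfo(qdtype)
    amax = weight.abs().max().clamp(min=1e-12)
    if scale_upper_bound is not None:
        # Assumed semantics: cap the observed maximum by the activation upper bound.
        amax = amax.clamp(max=scale_upper_bound)
    scale = finfo.max / amax
    qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max).to(qdtype)
    # Assumed convention: return the dequantization scale (weight ~ qweight * scale).
    return qweight, scale.reciprocal().float()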
@@ -33,7 +33,6 @@ from text_generation_server.layers.attention import (
     attention,
     reshape_and_cache,
 )
-from text_generation_server.models.globals import FLASH_DECODING
 from text_generation_server.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
@@ -42,16 +41,15 @@ from text_generation_server.layers import (
     TensorParallelMultiAdapterLinear,
     TensorParallelAdapterRowLinear,
 )
-from text_generation_server.layers.fp8 import Fp8Weight
 from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.layers.layernorm import (
     FastRMSNorm,
 )
 from text_generation_server.utils.weights import (
-    DefaultWeightsLoader,
     UnquantizedWeight,
     Weights,
 )
+from text_generation_server.layers.fp8 import HybridFP8UnquantLoader
 
 if SYSTEM == "rocm":
     try:
@@ -113,12 +111,12 @@ def load_attention(config, prefix: str, weights, layer_id):
 
 @contextmanager
 def no_fp8(weights: Weights):
+    """De-activate fp8 auto conversion for the duration of this context manager"""
     weights_loader = weights.weights_loader
-    if (
-        isinstance(weights_loader, DefaultWeightsLoader)
-        and weights_loader.weight_class is Fp8Weight
-    ):
-        weights_loader = DefaultWeightsLoader(UnquantizedWeight)
+    if isinstance(weights_loader, HybridFP8UnquantLoader) and weights_loader.to_fp8:
+        weights_loader = HybridFP8UnquantLoader(
+            weights_loader.activation_scale_ub, to_fp8=False
+        )
 
     with weights.use_loader(weights_loader):
         yield
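`no_fp8` now suspends auto conversion by swapping in a `HybridFP8UnquantLoader` with `to_fp8=False`, keeping the activation scale bound intact, instead of falling back to a `DefaultWeightsLoader`. A hedged usage sketch; the prefix and getter call are illustrative only, not taken from this diff:

# Hypothetical call site in a model's weight-loading code: tensors loaded
# inside the context manager keep the checkpoint dtype even when the loader
# was constructed with auto FP8 conversion enabled.
with no_fp8(weights):
    w = weights.get_weights_col(prefix=f"{prefix}.some_sensitive_proj")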
@@ -172,6 +172,6 @@ def get_loader(
         if isinstance(quantizer_config, _FP8QuantizerConfig):
             activation_scale_ub = quantizer_config.activation_scale_ub
 
-        return HybridFP8UnquantLoader(activation_scale_ub)
+        return HybridFP8UnquantLoader(activation_scale_ub, to_fp8=quantize == "fp8")
     else:
         raise ValueError(f"Unknown quantization method: {quantize}")
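The flag is therefore set only when FP8 quantization is explicitly requested: a model whose quantizer config is `_FP8QuantizerConfig` but whose `quantize` value is not "fp8" still gets the hybrid loader, just with auto conversion disabled. Sketch of the two resulting configurations (the constructor arguments shown are illustrative):

from text_generation_server.layers.fp8 import HybridFP8UnquantLoader

# quantize == "fp8": unquantized checkpoint tensors are auto-converted to FP8
loader = HybridFP8UnquantLoader(activation_scale_ub=None, to_fp8=True)

# any other quantize value reaching this branch: FP8 tensors load as FP8,
# everything else stays in the checkpoint dtype
loader = HybridFP8UnquantLoader(activation_scale_ub=None, to_fp8=False)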
|