mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 04:44:52 +00:00
missing get_weights implementation
This commit is contained in:
parent
b9410c3edf
commit
c9e8b68426
@ -71,6 +71,23 @@ class HybridFP8UnquantLoader(WeightsLoader):
|
|||||||
self.activation_scale_ub = activation_scale_ub
|
self.activation_scale_ub = activation_scale_ub
|
||||||
self.to_fp8 = to_fp8
|
self.to_fp8 = to_fp8
|
||||||
|
|
||||||
|
def get_weights(self, weights: "Weights", prefix: str):
|
||||||
|
w = weights.get_tensor(f"{prefix}.weight")
|
||||||
|
|
||||||
|
if w.dtype == torch.float8_e4m3fn:
|
||||||
|
# FP8 branch
|
||||||
|
scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
|
||||||
|
return Fp8Weight(
|
||||||
|
weight=w,
|
||||||
|
weight_scale=scale,
|
||||||
|
activation_scale_ub=self.activation_scale_ub,
|
||||||
|
dtype=weights.dtype,
|
||||||
|
)
|
||||||
|
if self.to_fp8:
|
||||||
|
return Fp8Weight(weight=w, dtype=weights.dtype)
|
||||||
|
|
||||||
|
return UnquantizedWeight(w)
|
||||||
|
|
||||||
def get_weights_col_packed(
|
def get_weights_col_packed(
|
||||||
self,
|
self,
|
||||||
weights: Weights,
|
weights: Weights,
|
||||||
|
Loading…
Reference in New Issue
Block a user