missing get_weights implementation

This commit is contained in:
OlivierDehaene 2024-07-20 09:56:46 +02:00
parent b9410c3edf
commit c9e8b68426
No known key found for this signature in database
GPG Key ID: BB104D67809DA93C

View File

@ -71,6 +71,23 @@ class HybridFP8UnquantLoader(WeightsLoader):
self.activation_scale_ub = activation_scale_ub
self.to_fp8 = to_fp8
def get_weights(self, weights: "Weights", prefix: str):
w = weights.get_tensor(f"{prefix}.weight")
if w.dtype == torch.float8_e4m3fn:
# FP8 branch
scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
return Fp8Weight(
weight=w,
weight_scale=scale,
activation_scale_ub=self.activation_scale_ub,
dtype=weights.dtype,
)
if self.to_fp8:
return Fp8Weight(weight=w, dtype=weights.dtype)
return UnquantizedWeight(w)
def get_weights_col_packed(
self,
weights: Weights,