mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 04:44:52 +00:00
missing get_weights implementation
This commit is contained in:
parent
b9410c3edf
commit
c9e8b68426
@ -71,6 +71,23 @@ class HybridFP8UnquantLoader(WeightsLoader):
|
||||
self.activation_scale_ub = activation_scale_ub
|
||||
self.to_fp8 = to_fp8
|
||||
|
||||
def get_weights(self, weights: "Weights", prefix: str):
|
||||
w = weights.get_tensor(f"{prefix}.weight")
|
||||
|
||||
if w.dtype == torch.float8_e4m3fn:
|
||||
# FP8 branch
|
||||
scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
|
||||
return Fp8Weight(
|
||||
weight=w,
|
||||
weight_scale=scale,
|
||||
activation_scale_ub=self.activation_scale_ub,
|
||||
dtype=weights.dtype,
|
||||
)
|
||||
if self.to_fp8:
|
||||
return Fp8Weight(weight=w, dtype=weights.dtype)
|
||||
|
||||
return UnquantizedWeight(w)
|
||||
|
||||
def get_weights_col_packed(
|
||||
self,
|
||||
weights: Weights,
|
||||
|
Loading…
Reference in New Issue
Block a user