(bug) update all has_tensor

This commit is contained in:
Mohit Sharma 2024-10-15 07:51:03 +00:00
parent 7a7cd5f299
commit b2b5024ec8
2 changed files with 4 additions and 4 deletions

View File

@ -392,7 +392,7 @@ class GPTQWeightsLoader(WeightsLoader):
) )
def _get_gptq_params(self, weights: Weights): def _get_gptq_params(self, weights: Weights):
if weights._has_tensor("gptq_bits") and weights._has_tensor("gptq_groupsize"): if weights.has_tensor("gptq_bits") and weights.has_tensor("gptq_groupsize"):
self.bits = weights.get_tensor("gptq_bits").item() self.bits = weights.get_tensor("gptq_bits").item()
self.groupsize = weights.get_tensor("gptq_groupsize").item() self.groupsize = weights.get_tensor("gptq_groupsize").item()
self.desc_act = False self.desc_act = False
@ -400,7 +400,7 @@ class GPTQWeightsLoader(WeightsLoader):
# before the `gptq_sym` setting tensor was added. # before the `gptq_sym` setting tensor was added.
self.sym = ( self.sym = (
weights.get_tensor("gptq_sym").item() weights.get_tensor("gptq_sym").item()
if weights._has_tensor("gptq_sym") if weights.has_tensor("gptq_sym")
else False else False
) )
self.quant_method = "gptq" self.quant_method = "gptq"

View File

@ -232,7 +232,7 @@ class GPTQMarlinWeightsLoader(WeightsLoader):
) )
def _get_gptq_params(self, weights: Weights): def _get_gptq_params(self, weights: Weights):
if weights._has_tensor("gptq_bits") and weights._has_tensor("gptq_groupsize"): if weights.has_tensor("gptq_bits") and weights.has_tensor("gptq_groupsize"):
self.bits = weights.get_tensor("gptq_bits").item() self.bits = weights.get_tensor("gptq_bits").item()
self.groupsize = weights.get_tensor("gptq_groupsize").item() self.groupsize = weights.get_tensor("gptq_groupsize").item()
self.desc_act = False self.desc_act = False
@ -240,7 +240,7 @@ class GPTQMarlinWeightsLoader(WeightsLoader):
# before the `gptq_sym` setting tensor was added. # before the `gptq_sym` setting tensor was added.
self.sym = ( self.sym = (
weights.get_tensor("gptq_sym").item() weights.get_tensor("gptq_sym").item()
if weights._has_tensor("gptq_sym") if weights.has_tensor("gptq_sym")
else False else False
) )
self.quant_method = "gptq" self.quant_method = "gptq"