fix(server): fix quantization

This commit is contained in:
OlivierDehaene 2023-05-30 13:56:03 +02:00
parent 49a6c8c1b2
commit bf7f1d5434
4 changed files with 24 additions and 32 deletions

View File

@ -245,14 +245,12 @@ class BLOOMSharded(BLOOM):
return linear
module.linear = replace_linear(state)
elif quantize == "gptq":
raise NotImplementedError(
"`gptq` is not implemented for now"
)
elif quantize is None:
tensor = tensor.to(device)
else:
raise ValueError(f"Unexpected quantize `{quantize}`")
elif quantize == "gptq":
raise NotImplementedError("`gptq` is not implemented for now")
elif quantize is None:
tensor = tensor.to(device)
else:
raise ValueError(f"Unexpected quantize `{quantize}`")
module._parameters[param_name] = tensor
if name == "word_embeddings.weight":

View File

@ -364,14 +364,12 @@ class GalacticaSharded(Galactica):
return linear
module.linear = replace_linear(state)
elif quantize == "gptq":
raise NotImplementedError(
"`gptq` is not implemented for now"
)
elif quantize is None:
tensor = tensor.to(device)
else:
raise ValueError(f"Unexpected quantize `{quantize}`")
elif quantize == "gptq":
raise NotImplementedError("`gptq` is not implemented for now")
elif quantize is None:
tensor = tensor.to(device)
else:
raise ValueError(f"Unexpected quantize `{quantize}`")
module._parameters[param_name] = tensor
if name == "model.decoder.embed_tokens.weight":

View File

@ -210,14 +210,12 @@ class GPTNeoxSharded(CausalLM):
return linear
module.linear = replace_linear(state)
elif quantize == "gptq":
raise NotImplementedError(
"`gptq` is not implemented for now"
)
elif quantize is None:
tensor = tensor.to(device)
else:
raise ValueError(f"Unexpected quantize `{quantize}`")
elif quantize == "gptq":
raise NotImplementedError("`gptq` is not implemented for now")
elif quantize is None:
tensor = tensor.to(device)
else:
raise ValueError(f"Unexpected quantize `{quantize}`")
if current_parameter_tensor is not None:
module._parameters[param_name] = tensor

View File

@ -223,14 +223,12 @@ class T5Sharded(Seq2SeqLM):
module.linear = replace_linear(state)
elif quantize == "gptq" and not module_name.endswith("wo"):
raise NotImplementedError(
"`gptq` is not implemented for now"
)
elif quantize is None:
tensor = tensor.to(device)
else:
raise ValueError(f"Unexpected quantize `{quantize}`")
elif quantize == "gptq" and not module_name.endswith("wo"):
raise NotImplementedError("`gptq` is not implemented for now")
elif quantize is None or module_name.endswith("wo"):
tensor = tensor.to(device)
else:
raise ValueError(f"Unexpected quantize `{quantize}`")
if current_parameter_tensor is not None:
module._parameters[param_name] = tensor