fix(server): fix quantization

This commit is contained in:
OlivierDehaene 2023-05-30 13:56:03 +02:00
parent 49a6c8c1b2
commit bf7f1d5434
4 changed files with 24 additions and 32 deletions

View File

@ -246,9 +246,7 @@ class BLOOMSharded(BLOOM):
module.linear = replace_linear(state) module.linear = replace_linear(state)
elif quantize == "gptq": elif quantize == "gptq":
raise NotImplementedError( raise NotImplementedError("`gptq` is not implemented for now")
"`gptq` is not implemented for now"
)
elif quantize is None: elif quantize is None:
tensor = tensor.to(device) tensor = tensor.to(device)
else: else:

View File

@ -365,9 +365,7 @@ class GalacticaSharded(Galactica):
module.linear = replace_linear(state) module.linear = replace_linear(state)
elif quantize == "gptq": elif quantize == "gptq":
raise NotImplementedError( raise NotImplementedError("`gptq` is not implemented for now")
"`gptq` is not implemented for now"
)
elif quantize is None: elif quantize is None:
tensor = tensor.to(device) tensor = tensor.to(device)
else: else:

View File

@ -211,9 +211,7 @@ class GPTNeoxSharded(CausalLM):
module.linear = replace_linear(state) module.linear = replace_linear(state)
elif quantize == "gptq": elif quantize == "gptq":
raise NotImplementedError( raise NotImplementedError("`gptq` is not implemented for now")
"`gptq` is not implemented for now"
)
elif quantize is None: elif quantize is None:
tensor = tensor.to(device) tensor = tensor.to(device)
else: else:

View File

@ -224,10 +224,8 @@ class T5Sharded(Seq2SeqLM):
module.linear = replace_linear(state) module.linear = replace_linear(state)
elif quantize == "gptq" and not module_name.endswith("wo"): elif quantize == "gptq" and not module_name.endswith("wo"):
raise NotImplementedError( raise NotImplementedError("`gptq` is not implemented for now")
"`gptq` is not implemented for now" elif quantize is None or module_name.endswith("wo"):
)
elif quantize is None:
tensor = tensor.to(device) tensor = tensor.to(device)
else: else:
raise ValueError(f"Unexpected quantize `{quantize}`") raise ValueError(f"Unexpected quantize `{quantize}`")