Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-04-26 12:32:10 +00:00
fix(server): fix quantization
parent 49a6c8c1b2
commit bf7f1d5434
@@ -245,14 +245,12 @@ class BLOOMSharded(BLOOM):
                         return linear
 
                     module.linear = replace_linear(state)
                 elif quantize == "gptq":
-                    raise NotImplementedError(
-                        "`gptq` is not implemented for now"
-                    )
+                    raise NotImplementedError("`gptq` is not implemented for now")
                 elif quantize is None:
                     tensor = tensor.to(device)
                 else:
                     raise ValueError(f"Unexpected quantize `{quantize}`")
 
             module._parameters[param_name] = tensor
             if name == "word_embeddings.weight":
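Every hunk in this commit edits the same weight-loading dispatch: `bitsandbytes` swaps the module's linear for a quantized replacement, `gptq` raises, no quantization moves the raw tensor to the device, and any other value is rejected. The sketch below is a minimal, self-contained rendering of that branch order, not the file's actual method: `place_tensor` is an illustrative name, and the bitsandbytes branch is elided because it depends on module state the hunks do not show.

from typing import Optional

import torch


def place_tensor(
    tensor: torch.Tensor, quantize: Optional[str], device: torch.device
) -> torch.Tensor:
    # Branch order shared by the BLOOM, Galactica, GPT-NeoX, and T5 hunks.
    if quantize == "bitsandbytes":
        # The real code swaps in a quantized linear at this point:
        #     module.linear = replace_linear(state)
        raise NotImplementedError("elided in this sketch")
    elif quantize == "gptq":
        raise NotImplementedError("`gptq` is not implemented for now")
    elif quantize is None:
        # Unquantized path: move the raw tensor to the target device.
        return tensor.to(device)
    else:
        raise ValueError(f"Unexpected quantize `{quantize}`")


# Example: the unquantized path just relocates the tensor.
print(place_tensor(torch.zeros(2), None, torch.device("cpu")))

The first three hunks (BLOOM, Galactica, GPT-NeoX) only collapse the multi-line raise into a single line; the T5 hunk at the end also changes behavior.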
@@ -364,14 +364,12 @@ class GalacticaSharded(Galactica):
                         return linear
 
                     module.linear = replace_linear(state)
                 elif quantize == "gptq":
-                    raise NotImplementedError(
-                        "`gptq` is not implemented for now"
-                    )
+                    raise NotImplementedError("`gptq` is not implemented for now")
                 elif quantize is None:
                     tensor = tensor.to(device)
                 else:
                     raise ValueError(f"Unexpected quantize `{quantize}`")
 
             module._parameters[param_name] = tensor
             if name == "model.decoder.embed_tokens.weight":
@@ -210,14 +210,12 @@ class GPTNeoxSharded(CausalLM):
                         return linear
 
                     module.linear = replace_linear(state)
                 elif quantize == "gptq":
-                    raise NotImplementedError(
-                        "`gptq` is not implemented for now"
-                    )
+                    raise NotImplementedError("`gptq` is not implemented for now")
                 elif quantize is None:
                     tensor = tensor.to(device)
                 else:
                     raise ValueError(f"Unexpected quantize `{quantize}`")
 
             if current_parameter_tensor is not None:
                 module._parameters[param_name] = tensor
@@ -223,14 +223,12 @@ class T5Sharded(Seq2SeqLM):
 
                     module.linear = replace_linear(state)
 
                 elif quantize == "gptq" and not module_name.endswith("wo"):
-                    raise NotImplementedError(
-                        "`gptq` is not implemented for now"
-                    )
-                elif quantize is None:
+                    raise NotImplementedError("`gptq` is not implemented for now")
+                elif quantize is None or module_name.endswith("wo"):
                     tensor = tensor.to(device)
                 else:
                     raise ValueError(f"Unexpected quantize `{quantize}`")
 
             if current_parameter_tensor is not None:
                 module._parameters[param_name] = tensor
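The T5 hunk carries the substantive fix. T5's `wo` projection is kept out of quantization (the `gptq` condition already skips it, and the bitsandbytes branch presumably does the same), so before this commit a `wo` tensor loaded with quantization enabled matched neither `elif` and fell through to the ValueError. Adding `or module_name.endswith("wo")` routes it through the plain `tensor.to(device)` path instead. Below is an illustrative sketch of just that branch order; the function and module names are made up for the example, and the bitsandbytes branch above the hunk is omitted.

from typing import Optional


def route_t5_tensor(module_name: str, quantize: Optional[str]) -> str:
    # Corrected branch order; returns a label instead of loading weights.
    if quantize == "gptq" and not module_name.endswith("wo"):
        raise NotImplementedError("`gptq` is not implemented for now")
    elif quantize is None or module_name.endswith("wo"):
        # A `wo` tensor now lands here even when `quantize` is set;
        # previously it fell through to the ValueError below.
        return "to_device"
    else:
        raise ValueError(f"Unexpected quantize `{quantize}`")


# Before the fix this raised ValueError; now the `wo` weight is simply
# loaded unquantized (the module name here is illustrative).
assert route_t5_tensor("encoder.block.0.layer.1.DenseReluDense.wo", "bitsandbytes") == "to_device"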