From bf7f1d54345f9e6e091d4b756e8b8381fe767d89 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Tue, 30 May 2023 13:56:03 +0200 Subject: [PATCH] fix(server): fix quantization --- server/text_generation_server/models/bloom.py | 14 ++++++-------- server/text_generation_server/models/galactica.py | 14 ++++++-------- server/text_generation_server/models/gpt_neox.py | 14 ++++++-------- server/text_generation_server/models/t5.py | 14 ++++++-------- 4 files changed, 24 insertions(+), 32 deletions(-) diff --git a/server/text_generation_server/models/bloom.py b/server/text_generation_server/models/bloom.py index 088a1457..90db59fc 100644 --- a/server/text_generation_server/models/bloom.py +++ b/server/text_generation_server/models/bloom.py @@ -245,14 +245,12 @@ class BLOOMSharded(BLOOM): return linear module.linear = replace_linear(state) - elif quantize == "gptq": - raise NotImplementedError( - "`gptq` is not implemented for now" - ) - elif quantize is None: - tensor = tensor.to(device) - else: - raise ValueError(f"Unexpected quantize `{quantize}`") + elif quantize == "gptq": + raise NotImplementedError("`gptq` is not implemented for now") + elif quantize is None: + tensor = tensor.to(device) + else: + raise ValueError(f"Unexpected quantize `{quantize}`") module._parameters[param_name] = tensor if name == "word_embeddings.weight": diff --git a/server/text_generation_server/models/galactica.py b/server/text_generation_server/models/galactica.py index 0a3f341b..954421f0 100644 --- a/server/text_generation_server/models/galactica.py +++ b/server/text_generation_server/models/galactica.py @@ -364,14 +364,12 @@ class GalacticaSharded(Galactica): return linear module.linear = replace_linear(state) - elif quantize == "gptq": - raise NotImplementedError( - "`gptq` is not implemented for now" - ) - elif quantize is None: - tensor = tensor.to(device) - else: - raise ValueError(f"Unexpected quantize `{quantize}`") + elif quantize == "gptq": + raise NotImplementedError("`gptq` is not implemented for now") + elif quantize is None: + tensor = tensor.to(device) + else: + raise ValueError(f"Unexpected quantize `{quantize}`") module._parameters[param_name] = tensor if name == "model.decoder.embed_tokens.weight": diff --git a/server/text_generation_server/models/gpt_neox.py b/server/text_generation_server/models/gpt_neox.py index e4a85082..c0e8170d 100644 --- a/server/text_generation_server/models/gpt_neox.py +++ b/server/text_generation_server/models/gpt_neox.py @@ -210,14 +210,12 @@ class GPTNeoxSharded(CausalLM): return linear module.linear = replace_linear(state) - elif quantize == "gptq": - raise NotImplementedError( - "`gptq` is not implemented for now" - ) - elif quantize is None: - tensor = tensor.to(device) - else: - raise ValueError(f"Unexpected quantize `{quantize}`") + elif quantize == "gptq": + raise NotImplementedError("`gptq` is not implemented for now") + elif quantize is None: + tensor = tensor.to(device) + else: + raise ValueError(f"Unexpected quantize `{quantize}`") if current_parameter_tensor is not None: module._parameters[param_name] = tensor diff --git a/server/text_generation_server/models/t5.py b/server/text_generation_server/models/t5.py index 17cc50e0..32dcb806 100644 --- a/server/text_generation_server/models/t5.py +++ b/server/text_generation_server/models/t5.py @@ -223,14 +223,12 @@ class T5Sharded(Seq2SeqLM): module.linear = replace_linear(state) - elif quantize == "gptq" and not module_name.endswith("wo"): - raise NotImplementedError( - "`gptq` is not implemented for now" - ) - elif quantize is None: - tensor = tensor.to(device) - else: - raise ValueError(f"Unexpected quantize `{quantize}`") + elif quantize == "gptq" and not module_name.endswith("wo"): + raise NotImplementedError("`gptq` is not implemented for now") + elif quantize is None or module_name.endswith("wo"): + tensor = tensor.to(device) + else: + raise ValueError(f"Unexpected quantize `{quantize}`") if current_parameter_tensor is not None: module._parameters[param_name] = tensor