diff --git a/server/text_generation_server/models/bloom.py b/server/text_generation_server/models/bloom.py
index 90db59fc..45d7cd4c 100644
--- a/server/text_generation_server/models/bloom.py
+++ b/server/text_generation_server/models/bloom.py
@@ -245,6 +245,8 @@ class BLOOMSharded(BLOOM):
                                 return linear
 
                             module.linear = replace_linear(state)
+                        else:
+                            tensor = tensor.to(device)
                     elif quantize == "gptq":
                         raise NotImplementedError("`gptq` is not implemented for now")
                     elif quantize is None:
diff --git a/server/text_generation_server/models/galactica.py b/server/text_generation_server/models/galactica.py
index 954421f0..37ccc398 100644
--- a/server/text_generation_server/models/galactica.py
+++ b/server/text_generation_server/models/galactica.py
@@ -364,6 +364,8 @@ class GalacticaSharded(Galactica):
                                 return linear
 
                             module.linear = replace_linear(state)
+                        else:
+                            tensor = tensor.to(device)
                     elif quantize == "gptq":
                         raise NotImplementedError("`gptq` is not implemented for now")
                     elif quantize is None:
diff --git a/server/text_generation_server/models/gpt_neox.py b/server/text_generation_server/models/gpt_neox.py
index c0e8170d..5ab8a624 100644
--- a/server/text_generation_server/models/gpt_neox.py
+++ b/server/text_generation_server/models/gpt_neox.py
@@ -210,6 +210,8 @@ class GPTNeoxSharded(CausalLM):
                                 return linear
 
                             module.linear = replace_linear(state)
+                        else:
+                            tensor = tensor.to(device)
                     elif quantize == "gptq":
                         raise NotImplementedError("`gptq` is not implemented for now")
                     elif quantize is None:
diff --git a/server/text_generation_server/models/opt.py b/server/text_generation_server/models/opt.py
index bccce5b3..9cc4d5e1 100644
--- a/server/text_generation_server/models/opt.py
+++ b/server/text_generation_server/models/opt.py
@@ -166,7 +166,7 @@ class OPTSharded(OPT):
 
                     tensor = tensor.contiguous().to(dtype)
 
-                    if quantize:
+                    if quantize == "bitsandbytes":
                         if not HAS_BITS_AND_BYTES:
                             raise ImportError(
                                 "bitsandbytes is not available on your machine either because it is not installed "
@@ -216,9 +216,14 @@ class OPTSharded(OPT):
 
                                 return linear
 
                             module.linear = replace_linear(state)
-
                         else:
                             tensor = tensor.to(device)
+                    elif quantize == "gptq":
+                        raise NotImplementedError("`gptq` is not implemented for now")
+                    elif quantize is None:
+                        tensor = tensor.to(device)
+                    else:
+                        raise ValueError(f"Unexpected quantize `{quantize}`")
 
                     module._parameters[param_name] = tensor
diff --git a/server/text_generation_server/models/t5.py b/server/text_generation_server/models/t5.py
index 32dcb806..d12b89d2 100644
--- a/server/text_generation_server/models/t5.py
+++ b/server/text_generation_server/models/t5.py
@@ -222,7 +222,8 @@ class T5Sharded(Seq2SeqLM):
                                 return linear
 
                             module.linear = replace_linear(state)
-
+                        else:
+                            tensor = tensor.to(device)
                     elif quantize == "gptq" and not module_name.endswith("wo"):
                         raise NotImplementedError("`gptq` is not implemented for now")
                     elif quantize is None or module_name.endswith("wo"):
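
For reference, a minimal sketch of the `quantize` dispatch this patch makes uniform across the five sharded models. When quantizing, weights are staged on CPU at load time, so any parameter that fell through the inner `if` under `quantize == "bitsandbytes"` (embeddings, layer norms, biases) was never moved to the target device; the added `else` branch fixes that, and `opt.py` additionally stops treating every truthy `quantize` value as bitsandbytes. The function name `place_tensor` and the `is_tp_linear_weight` flag below are hypothetical stand-ins for the surrounding `load_weights` logic, not names from the repository:

```python
# Sketch only: `place_tensor` and `is_tp_linear_weight` are illustrative
# names; the real code lives inline in each model's load_weights method.
from typing import Optional

import torch


def place_tensor(
    tensor: torch.Tensor,
    device: torch.device,
    quantize: Optional[str],
    is_tp_linear_weight: bool,
) -> torch.Tensor:
    if quantize == "bitsandbytes":
        if is_tp_linear_weight:
            # Real code wraps the weight in bitsandbytes Int8Params and
            # monkey-patches module.linear to route through bnb.matmul
            # (elided in this sketch).
            tensor = tensor.to(device)
        else:
            # The branch this patch adds: non-linear parameters are staged
            # on CPU when quantizing and still need moving to the device.
            tensor = tensor.to(device)
    elif quantize == "gptq":
        raise NotImplementedError("`gptq` is not implemented for now")
    elif quantize is None:
        tensor = tensor.to(device)
    else:
        raise ValueError(f"Unexpected quantize `{quantize}`")
    return tensor
```

E.g. `place_tensor(torch.zeros(2, 2), torch.device("cuda:0"), "bitsandbytes", False)` now lands the tensor on the GPU instead of silently leaving it on CPU.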