From 7402a355dcbf9ffe7a0b2a788f2062aa9e0a3ed5 Mon Sep 17 00:00:00 2001
From: momonga <146910567+mmngays@users.noreply.github.com>
Date: Thu, 19 Oct 2023 17:42:03 +0900
Subject: [PATCH] Fix calling cuda() on load_in_8bit (#1153)

This PR addresses an issue where calling `model = model.cuda()` would throw a ValueError when `quantize` is set to "bitsandbytes".

```
> File "/opt/conda/lib/python3.9/site-packages/text_generation_server/server.py", line 147, in serve_inner
    model = get_model(
  File "/opt/conda/lib/python3.9/site-packages/text_generation_server/models/__init__.py", line 295, in get_model
    return CausalLM(
  File "/opt/conda/lib/python3.9/site-packages/text_generation_server/models/causal_lm.py", line 515, in __init__
    model = model.cuda()
  File "/opt/conda/lib/python3.9/site-packages/transformers/modeling_utils.py", line 1998, in cuda
    raise ValueError(
ValueError: Calling `cuda()` is not supported for `4-bit` or `8-bit` quantized models. Please use the model as it is, since the model has already been set to the correct devices and casted to the correct `dtype`.
```

Co-authored-by: mmnga
---
 server/text_generation_server/models/causal_lm.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index fccfb0f8..8056a8ec 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -511,7 +511,7 @@ class CausalLM(Model):
             load_in_8bit=quantize == "bitsandbytes",
             trust_remote_code=trust_remote_code,
         )
-        if torch.cuda.is_available() and torch.cuda.device_count() == 1:
+        if torch.cuda.is_available() and torch.cuda.device_count() == 1 and quantize != "bitsandbytes":
             model = model.cuda()
 
         if tokenizer.pad_token_id is None:
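
For context, below is a minimal, self-contained sketch of the loading path after this change. The `load_model` helper is hypothetical and only reproduces the kwargs visible in the hunk above (the real `CausalLM.__init__` passes more); it is meant to illustrate why the explicit `.cuda()` move has to be skipped for bitsandbytes loads, not to mirror the server code exactly.

```python
# Hypothetical sketch, not the actual CausalLM.__init__: only the kwargs visible
# in the hunk above are reproduced here.
from typing import Optional

import torch
from transformers import AutoModelForCausalLM


def load_model(model_id: str, quantize: Optional[str], trust_remote_code: bool = False):
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        load_in_8bit=quantize == "bitsandbytes",
        trust_remote_code=trust_remote_code,
    )
    # An 8-bit model is already placed on the correct device during from_pretrained,
    # and transformers raises the ValueError shown above if .cuda() is called on it,
    # so the explicit move is only done for unquantized single-GPU loads.
    if (
        torch.cuda.is_available()
        and torch.cuda.device_count() == 1
        and quantize != "bitsandbytes"
    ):
        model = model.cuda()
    return model
```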