diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index bffbfa64..0a44aafc
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -3,7 +3,7 @@ import inspect
 from dataclasses import dataclass
 from opentelemetry import trace
-from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
+from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase, BitsAndBytesConfig
 from typing import Optional, Tuple, List, Type, Dict
 from text_generation_server.models import Model
@@ -474,14 +474,20 @@ class CausalLM(Model):
             truncation_side="left",
             trust_remote_code=trust_remote_code,
         )
+
+        should_quantize = quantize == "bitsandbytes"
+        if should_quantize:
+            quantization_config = BitsAndBytesConfig(
+                load_in_4bit=True,
+                bnb_4bit_compute_dtype=torch.float16,
+            )
         model = AutoModelForCausalLM.from_pretrained(
             model_id,
             revision=revision,
             torch_dtype=dtype,
-            device_map="auto"
-            if torch.cuda.is_available() and torch.cuda.device_count() > 1
-            else None,
-            load_in_4bit=quantize == "bitsandbytes",
+            device_map="auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None,
+            load_in_4bit=False if not should_quantize else None,
+            quantization_config=quantization_config if should_quantize else None,
             trust_remote_code=trust_remote_code,
         )

This addresses the error raised when `.cuda()` is later called on a model that was already quantized at load time:

ValueError: Calling `cuda()` is not supported for `4-bit` or `8-bit` quantized models. Please use the model as it is, since the model has already been set to the correct devices and casted to the correct `dtype`.
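
For reference, below is a minimal standalone sketch of the loading pattern this patch switches to (a `BitsAndBytesConfig` passed as `quantization_config`, with no subsequent `.cuda()` call). It is not part of the patch: the model id is a placeholder, and it assumes `bitsandbytes` and `accelerate` are installed.

```python
# Standalone sketch of 4-bit loading via BitsAndBytesConfig.
# "facebook/opt-125m" is only a placeholder model id for illustration.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_id = "facebook/opt-125m"

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",  # let accelerate place the quantized weights
    quantization_config=quantization_config,
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Do NOT call model.cuda() or model.to("cuda") afterwards: the quantized model
# is already on the correct device, and doing so raises the ValueError above.
```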