Loading quantization config the proper way

Chris 2023-08-27 16:48:30 +02:00
parent cf178a278a
commit 694a535033


@@ -3,7 +3,7 @@ import inspect
 from dataclasses import dataclass
 from opentelemetry import trace
-from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
+from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase, BitsAndBytesConfig
 from typing import Optional, Tuple, List, Type, Dict
 from text_generation_server.models import Model
@@ -474,14 +474,20 @@ class CausalLM(Model):
             truncation_side="left",
             trust_remote_code=trust_remote_code,
         )
+        should_quantize = quantize == "bitsandbytes"
+        if(should_quantize):
+            quantization_config = BitsAndBytesConfig(
+                load_in_4bit=True,
+                bnb_4bit_compute_dtype=torch.float16
+            )
         model = AutoModelForCausalLM.from_pretrained(
             model_id,
             revision=revision,
             torch_dtype=dtype,
-            device_map="auto"
-            if torch.cuda.is_available() and torch.cuda.device_count() > 1
-            else None,
-            load_in_4bit=quantize == "bitsandbytes",
+            device_map="auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None,
+            load_in_4bit= False if not should_quantize else None,
+            quantization_config = quantization_config if should_quantize else None,
             trust_remote_code=trust_remote_code,
         )
+        ## ValueError: Calling `cuda()` is not supported for `4-bit` or `8-bit` quantized models. Please use the model as it is, since the model has already been set to the correct devices and casted to the correct `dtype`.
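
For context, this is the pattern the commit moves to: build a BitsAndBytesConfig and pass it to from_pretrained instead of the bare load_in_4bit flag, and leave device placement to device_map rather than calling .cuda() on the quantized model. A minimal standalone sketch, assuming bitsandbytes and a CUDA GPU are available; the model id below is only a placeholder, not part of this commit.

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

    model_id = "bigscience/bloom-560m"  # placeholder; any causal LM repo id works

    # Express the 4-bit settings through a config object so the compute dtype
    # can be set explicitly (the bare load_in_4bit flag cannot carry it).
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        quantization_config=quantization_config,
        device_map="auto",  # let accelerate place the quantized weights
    )

    # Do not call model.cuda() or model.to("cuda") afterwards: 4-bit/8-bit models
    # are already on the correct devices, and doing so raises the ValueError noted above.

Passing a config object rather than the plain flag is also what lets the commit set bnb_4bit_compute_dtype to float16, which keeps the matmul compute dtype from falling back to float32.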