Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-10 20:04:52 +00:00)
commit 694a535033 (parent cf178a278a)

    Load the quantization config the proper way
@@ -3,7 +3,7 @@ import inspect
 
 from dataclasses import dataclass
 from opentelemetry import trace
-from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
+from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase, BitsAndBytesConfig
 from typing import Optional, Tuple, List, Type, Dict
 
 from text_generation_server.models import Model
@@ -474,14 +474,20 @@ class CausalLM(Model):
             truncation_side="left",
             trust_remote_code=trust_remote_code,
         )
 
+        should_quantize = quantize == "bitsandbytes"
+        if should_quantize:
+            quantization_config = BitsAndBytesConfig(
+                load_in_4bit=True,
+                bnb_4bit_compute_dtype=torch.float16
+            )
         model = AutoModelForCausalLM.from_pretrained(
             model_id,
             revision=revision,
             torch_dtype=dtype,
-            device_map="auto"
-            if torch.cuda.is_available() and torch.cuda.device_count() > 1
-            else None,
-            load_in_4bit=quantize == "bitsandbytes",
+            device_map="auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None,
+            load_in_4bit=False if not should_quantize else None,
+            quantization_config=quantization_config if should_quantize else None,
             trust_remote_code=trust_remote_code,
         )
+        ## ValueError: Calling `cuda()` is not supported for `4-bit` or `8-bit` quantized models. Please use the model as it is, since the model has already been set to the correct devices and casted to the correct `dtype`.
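
Below is a minimal, self-contained sketch of the pattern this commit adopts: build a BitsAndBytesConfig up front and pass it to from_pretrained(), rather than setting the bare load_in_4bit flag. The model id and prompt are placeholders, not part of the commit; running it assumes a CUDA GPU with the bitsandbytes and accelerate packages installed.

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

    model_id = "facebook/opt-125m"  # arbitrary example model, not from the commit

    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,  # matmuls run in fp16; weights stay 4-bit
    )

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",  # accelerate places the quantized weights on devices
        quantization_config=quantization_config,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    # Do NOT call model.cuda() afterwards: quantized models are already placed
    # on the correct devices, which is exactly the ValueError quoted above.
    inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
    output = model.generate(**inputs, max_new_tokens=8)
    print(tokenizer.decode(output[0], skip_special_tokens=True))

Passing the config object rather than the loose flag keeps the quantization settings (4-bit mode, compute dtype) in one place and avoids mixing load_in_4bit with quantization_config in the same call, which newer transformers versions reject.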