Mirror of https://github.com/huggingface/text-generation-inference.git
loading quantization config the proper way
commit 694a535033 (parent cf178a278a)
@@ -3,7 +3,7 @@ import inspect
 
 from dataclasses import dataclass
 from opentelemetry import trace
-from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
+from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase, BitsAndBytesConfig
 from typing import Optional, Tuple, List, Type, Dict
 
 from text_generation_server.models import Model
@@ -474,14 +474,20 @@ class CausalLM(Model):
             truncation_side="left",
             trust_remote_code=trust_remote_code,
         )
+
+        should_quantize = quantize == "bitsandbytes"
+        if(should_quantize):
+            quantization_config = BitsAndBytesConfig(
+                load_in_4bit=True,
+                bnb_4bit_compute_dtype=torch.float16
+            )
         model = AutoModelForCausalLM.from_pretrained(
             model_id,
             revision=revision,
             torch_dtype=dtype,
-            device_map="auto"
-            if torch.cuda.is_available() and torch.cuda.device_count() > 1
-            else None,
-            load_in_4bit=quantize == "bitsandbytes",
+            device_map="auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None,
+            load_in_4bit= False if not should_quantize else None,
+            quantization_config = quantization_config if should_quantize else None,
             trust_remote_code=trust_remote_code,
         )
         ## ValueError: Calling `cuda()` is not supported for `4-bit` or `8-bit` quantized models. Please use the model as it is, since the model has already been set to the correct devices and casted to the correct `dtype`.
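
For reference, a minimal standalone sketch of the pattern this commit moves to: building a BitsAndBytesConfig and passing it as quantization_config to from_pretrained, instead of the bare load_in_4bit flag. This is not the TGI code path itself; the model id and the generation call at the end are placeholders, and it assumes bitsandbytes and a CUDA device are available.

# Minimal sketch (assumptions: bitsandbytes installed, CUDA GPU present,
# "bigscience/bloom-560m" used only as a placeholder model id).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_id = "bigscience/bloom-560m"  # placeholder, not part of the commit

# Quantization settings mirroring the diff: 4-bit weights, fp16 compute.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
)

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    quantization_config=quantization_config,
)

# Do not call model.cuda() or model.to(...) afterwards: transformers raises the
# ValueError quoted in the diff, because quantized weights are already placed
# on the correct devices by device_map.
inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=10)[0]))

The follow-up note in the diff (the commented-out ValueError) is the reason for passing device_map/quantization_config up front rather than moving the model afterwards.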