Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 04:14:52 +00:00)
loading adapter model ontop
commit aba56c1343
parent 694a535033
@@ -1,10 +1,12 @@
 import torch
 import inspect
+import json
 
 from dataclasses import dataclass
 from opentelemetry import trace
 from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase, BitsAndBytesConfig
 from typing import Optional, Tuple, List, Type, Dict
+from peft import PeftModelForCausalLM, get_peft_config, PeftConfig
 
 from text_generation_server.models import Model
 from text_generation_server.models.types import (
@@ -476,11 +478,15 @@ class CausalLM(Model):
         )
 
         should_quantize = quantize == "bitsandbytes"
-        if(should_quantize):
+        if should_quantize:
             quantization_config = BitsAndBytesConfig(
                 load_in_4bit=True,
                 bnb_4bit_compute_dtype=torch.float16
             )
+
+        has_peft_model = True
+        peft_model_id_or_path = "/mnt/TOFU/HF_MODELS/Llama-2-7b-chat-hf-instruct-pl-lora_adapter_model"
+
         model = AutoModelForCausalLM.from_pretrained(
             model_id,
             revision=revision,
@@ -490,6 +496,15 @@ class CausalLM(Model):
             quantization_config = quantization_config if should_quantize else None,
             trust_remote_code=trust_remote_code,
         )
+        if has_peft_model:
+            with open(f'{peft_model_id_or_path}/adapter_config.json') as config_file:
+                config = json.load(config_file)
+            peft_config = get_peft_config(config)
+            model = PeftModelForCausalLM(model, peft_config)
+            ## Llama does not have a load_adapter method - we need to think about hot swapping here and implement this for Llama
+            # model.load_adapter(peft_model_id_or_path)
+            # model.enable_adapters()
+
         ## ValueError: Calling `cuda()` is not supported for `4-bit` or `8-bit` quantized models. Please use the model as it is, since the model has already been set to the correct devices and casted to the correct `dtype`.
         # if torch.cuda.is_available() and torch.cuda.device_count() == 1:
         #     model = model.cuda()
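
Note (not part of this commit): constructing PeftModelForCausalLM directly from the parsed adapter_config.json only injects the adapter architecture described by the config; it does not load the trained adapter weights saved alongside it. PeftModel.from_pretrained reads both the config and the weights, and it also supports named adapters that can be swapped at runtime, which may be relevant to the hot-swapping comment above. Below is a minimal sketch of that approach; the base-model id, adapter paths, and adapter names are illustrative assumptions, not values from the commit.

# Sketch only: load trained LoRA weights and switch adapters with PEFT.
# Base model id, adapter paths, and adapter names below are assumptions.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel

base_model_id = "meta-llama/Llama-2-7b-chat-hf"   # assumed base checkpoint
adapter_path = "/path/to/lora_adapter"            # assumed dir holding adapter_config.json + adapter weights

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
)

base = AutoModelForCausalLM.from_pretrained(
    base_model_id,
    quantization_config=quantization_config,
    device_map="auto",
)

# from_pretrained reads adapter_config.json *and* the saved adapter weights,
# unlike building PeftModelForCausalLM from the config alone.
model = PeftModel.from_pretrained(base, adapter_path, adapter_name="default")

# Hot swapping: attach further adapters and select between them without
# reloading the quantized base model.
# model.load_adapter("/path/to/other_lora_adapter", adapter_name="other")  # assumed second adapter
# model.set_adapter("other")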