loading adapter model on top

Chris 2023-08-27 17:32:28 +02:00
parent 694a535033
commit aba56c1343


@@ -1,10 +1,12 @@
 import torch
 import inspect
+import json
 from dataclasses import dataclass
 from opentelemetry import trace
 from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase, BitsAndBytesConfig
 from typing import Optional, Tuple, List, Type, Dict
+from peft import PeftModelForCausalLM, get_peft_config, PeftConfig
 from text_generation_server.models import Model
 from text_generation_server.models.types import (
@@ -476,11 +478,15 @@ class CausalLM(Model):
         )
         should_quantize = quantize == "bitsandbytes"
-        if(should_quantize):
+        if should_quantize:
             quantization_config = BitsAndBytesConfig(
                 load_in_4bit=True,
                 bnb_4bit_compute_dtype=torch.float16
             )
+        has_peft_model = True
+        peft_model_id_or_path = "/mnt/TOFU/HF_MODELS/Llama-2-7b-chat-hf-instruct-pl-lora_adapter_model"
         model = AutoModelForCausalLM.from_pretrained(
             model_id,
             revision=revision,
@@ -490,6 +496,15 @@ class CausalLM(Model):
             quantization_config = quantization_config if should_quantize else None,
             trust_remote_code=trust_remote_code,
         )
+        if has_peft_model:
+            with open(f'{peft_model_id_or_path}/adapter_config.json') as config_file:
+                config = json.load(config_file)
+            peft_config = get_peft_config(config)
+            model = PeftModelForCausalLM(model, peft_config)
+            ## Llama does not have a load_adapter method - we need to think about hot swapping here and implement this for Llama
+            # model.load_adapter(peft_model_id_or_path)
+            # model.enable_adapters()
+            ## ValueError: Calling `cuda()` is not supported for `4-bit` or `8-bit` quantized models. Please use the model as it is, since the model has already been set to the correct devices and casted to the correct `dtype`.
+            # if torch.cuda.is_available() and torch.cuda.device_count() == 1:
+            #     model = model.cuda()
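
Regarding the commented-out load_adapter/hot-swapping lines above: a minimal sketch (not part of this commit) of the same step using peft's PeftModel.from_pretrained, which reads adapter_config.json itself and exposes load_adapter/set_adapter for switching adapters without reloading the quantized base model. The base model id, adapter paths, and adapter names below are placeholders, not values from this commit.

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel

# 4-bit base model; the quantized weights are already placed on device,
# so no .cuda() call is needed (see the ValueError quoted above).
base_model = AutoModelForCausalLM.from_pretrained(
    "base-model-id",  # placeholder
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
    ),
    device_map="auto",
)

# Attach a LoRA adapter on top of the quantized base model.
model = PeftModel.from_pretrained(
    base_model,
    "/path/to/lora_adapter_model",  # placeholder adapter directory
    adapter_name="instruct-pl",
)

# Load a second adapter and switch between them at runtime.
model.load_adapter("/path/to/other_adapter", adapter_name="other")  # placeholder
model.set_adapter("other")
model.set_adapter("instruct-pl")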