loading adapter model on top

Chris 2023-08-27 17:32:28 +02:00
parent 694a535033
commit aba56c1343

@@ -1,10 +1,12 @@
 import torch
 import inspect
+import json
 from dataclasses import dataclass
 from opentelemetry import trace
 from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase, BitsAndBytesConfig
 from typing import Optional, Tuple, List, Type, Dict
+from peft import PeftModelForCausalLM, get_peft_config, PeftConfig
 from text_generation_server.models import Model
 from text_generation_server.models.types import (
@@ -476,11 +478,15 @@ class CausalLM(Model):
         )
         should_quantize = quantize == "bitsandbytes"
-        if(should_quantize):
+        if should_quantize:
             quantization_config = BitsAndBytesConfig(
                 load_in_4bit=True,
                 bnb_4bit_compute_dtype=torch.float16
             )
+        has_peft_model = True
+        peft_model_id_or_path = "/mnt/TOFU/HF_MODELS/Llama-2-7b-chat-hf-instruct-pl-lora_adapter_model"
         model = AutoModelForCausalLM.from_pretrained(
             model_id,
             revision=revision,
@@ -490,6 +496,15 @@ class CausalLM(Model):
             quantization_config = quantization_config if should_quantize else None,
             trust_remote_code=trust_remote_code,
         )
+        if has_peft_model:
+            with open(f'{peft_model_id_or_path}/adapter_config.json') as config_file:
+                config = json.load(config_file)
+            peft_config = get_peft_config(config)
+            model = PeftModelForCausalLM(model, peft_config)
+            ## Llama does not have a load_adapter method - we need to think about hot swapping here and implement this for Llama
+            # model.load_adapter(peft_model_id_or_path)
+            # model.enable_adapters()
         ## ValueError: Calling `cuda()` is not supported for `4-bit` or `8-bit` quantized models. Please use the model as it is, since the model has already been set to the correct devices and casted to the correct `dtype`.
         # if torch.cuda.is_available() and torch.cuda.device_count() == 1:
         #     model = model.cuda()
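
For orientation, the change amounts to loading a 4-bit quantized base model and then wrapping it with a LoRA adapter. Below is a minimal standalone sketch of that flow, with placeholder model id and adapter path (not the repo's real values), using peft's PeftModel.from_pretrained helper rather than building the PeftConfig by hand as the commit does:

# Minimal sketch (not the server code): load a 4-bit quantized base model and
# apply a LoRA adapter on top of it. The id and path below are placeholders.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel

base_model_id = "meta-llama/Llama-2-7b-chat-hf"   # placeholder base model id
adapter_path = "/path/to/lora_adapter"            # placeholder dir containing adapter_config.json

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
)

# bitsandbytes places the quantized weights on the target device while loading,
# which is why calling .cuda() afterwards raises the ValueError quoted above.
model = AutoModelForCausalLM.from_pretrained(
    base_model_id,
    quantization_config=quantization_config,
    device_map="auto",
)

# PeftModel.from_pretrained reads adapter_config.json and loads the trained
# adapter weights, whereas constructing PeftModelForCausalLM(model, peft_config)
# from the config alone only initializes fresh adapter layers.
model = PeftModel.from_pretrained(model, adapter_path)
model.eval()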