diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index 0a44aafc..8f8adad9 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -1,10 +1,12 @@
 import torch
 import inspect
-
+import json
+
 from dataclasses import dataclass
 from opentelemetry import trace
 from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase, BitsAndBytesConfig
 from typing import Optional, Tuple, List, Type, Dict
+from peft import PeftModelForCausalLM, get_peft_config, PeftConfig

 from text_generation_server.models import Model
 from text_generation_server.models.types import (
@@ -476,11 +478,15 @@ class CausalLM(Model):
         )

         should_quantize = quantize == "bitsandbytes"
-        if(should_quantize):
+        if should_quantize:
             quantization_config = BitsAndBytesConfig(
                 load_in_4bit=True,
                 bnb_4bit_compute_dtype=torch.float16
             )
+
+        has_peft_model = True
+        peft_model_id_or_path = "/mnt/TOFU/HF_MODELS/Llama-2-7b-chat-hf-instruct-pl-lora_adapter_model"
+
         model = AutoModelForCausalLM.from_pretrained(
             model_id,
             revision=revision,
@@ -490,6 +496,15 @@ class CausalLM(Model):
             quantization_config = quantization_config if should_quantize else None,
             trust_remote_code=trust_remote_code,
         )
+        if has_peft_model:
+            with open(f'{peft_model_id_or_path}/adapter_config.json') as config_file:
+                config = json.load(config_file)
+            peft_config = get_peft_config(config)
+            model = PeftModelForCausalLM(model, peft_config)
+            ## Llama does not have a load_adapter method - we need to think about hot swapping here and implement this for Llama
+            # model.load_adapter(peft_model_id_or_path)
+            # model.enable_adapters()
+
         ## ValueError: Calling `cuda()` is not supported for `4-bit` or `8-bit` quantized models. Please use the model as it is, since the model has already been set to the correct devices and casted to the correct `dtype`.
         # if torch.cuda.is_available() and torch.cuda.device_count() == 1:
         #     model = model.cuda()
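
Note on the adapter-loading step above: constructing PeftModelForCausalLM(model, peft_config) attaches LoRA modules built from adapter_config.json but, as far as I can tell, does not load the trained adapter weights from disk. Below is a minimal sketch of an alternative that uses PeftModel.from_pretrained, which reads both the config and the weights in one call; the base model id and adapter paths are hypothetical placeholders, not the paths from this diff.

# Minimal sketch, not part of the diff: load a quantized base model and attach a
# LoRA adapter with its trained weights via PeftModel.from_pretrained.
# The model id and adapter paths below are hypothetical placeholders.
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

base_model_id = "meta-llama/Llama-2-7b-chat-hf"   # assumed base model
adapter_path = "/path/to/lora_adapter"            # assumed adapter directory

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
)

# device_map="auto" lets accelerate place the 4-bit model; no explicit .cuda()
# call is needed (and, as the diff comment notes, it is unsupported for 4-bit models).
model = AutoModelForCausalLM.from_pretrained(
    base_model_id,
    quantization_config=quantization_config,
    device_map="auto",
)

# Reads adapter_config.json and the adapter weights from adapter_path.
model = PeftModel.from_pretrained(model, adapter_path, adapter_name="default")

# Further adapters can be loaded and switched at runtime, which is one possible
# route to the hot swapping mentioned in the diff comments.
# model.load_adapter("/path/to/another_lora_adapter", adapter_name="alternate")
# model.set_adapter("alternate")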