From 1659b871b655eed41ca96ff3033e7eb82be9b18a Mon Sep 17 00:00:00 2001
From: Chris
Date: Sun, 27 Aug 2023 20:56:53 +0200
Subject: [PATCH] load peft from cli

---
 server/Makefile                                    |  2 +-
 server/text_generation_server/cli.py               |  3 ++-
 server/text_generation_server/models/__init__.py   |  5 +++++
 server/text_generation_server/models/causal_lm.py  | 10 +++++++---
 server/text_generation_server/server.py            |  6 ++++--
 5 files changed, 19 insertions(+), 7 deletions(-)

diff --git a/server/Makefile b/server/Makefile
index adea6b31..01e215e5 100644
--- a/server/Makefile
+++ b/server/Makefile
@@ -24,6 +24,6 @@ install: gen-server install-torch
 	pip install -e ".[bnb, accelerate]"
 
 run-dev:
-	SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=1 text_generation_server/cli.py serve /mnt/TOFU/HF_MODELS/Llama-2-7b-chat-hf --quantize bitsandbytes
+	SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=1 text_generation_server/cli.py serve /mnt/TOFU/HF_MODELS/Llama-2-7b-chat-hf --quantize bitsandbytes --peft-model-path /mnt/TOFU/HF_MODELS/Llama-2-7b-chat-hf-function-calling-adapters-v2
 
 export-requirements:
 	poetry export -o requirements.txt -E bnb -E quantize --without-hashes
diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py
index 38e00d2d..647f2eef 100644
--- a/server/text_generation_server/cli.py
+++ b/server/text_generation_server/cli.py
@@ -33,6 +33,7 @@ def serve(
     dtype: Optional[Dtype] = None,
     trust_remote_code: bool = False,
     uds_path: Path = "/tmp/text-generation-server",
+    peft_model_path: Optional[Path] = None,
     logger_level: str = "INFO",
     json_output: bool = False,
     otlp_endpoint: Optional[str] = None,
@@ -79,7 +80,7 @@
             "Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."
         )
     server.serve(
-        model_id, revision, sharded, quantize, dtype, trust_remote_code, uds_path
+        model_id, revision, sharded, quantize, dtype, trust_remote_code, uds_path, peft_model_path
     )
diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index 932ab32e..e6d8afd5 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -5,6 +5,7 @@ from loguru import logger
 from transformers.configuration_utils import PretrainedConfig
 from transformers.models.auto import modeling_auto
 from typing import Optional
+from pathlib import Path
 
 from text_generation_server.models.model import Model
 from text_generation_server.models.causal_lm import CausalLM
@@ -75,6 +76,7 @@ def get_model(
     quantize: Optional[str],
     dtype: Optional[str],
     trust_remote_code: bool,
+    peft_model_path: Optional[Path],
 ) -> Model:
     if dtype is None:
         dtype = torch.float16
@@ -180,6 +182,7 @@
             quantize=quantize,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
+            peft_model_path=peft_model_path,
         )
 
     elif model_type == "llama":
@@ -200,6 +203,7 @@
             quantize=quantize,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
+            peft_model_path=peft_model_path,
         )
 
     if model_type in ["RefinedWeb", "RefinedWebModel", "falcon"]:
@@ -298,6 +302,7 @@
             quantize=quantize,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
+            peft_model_path=peft_model_path,
         )
     if "AutoModelForSeq2SeqLM" in auto_map.keys():
         return Seq2SeqLM(
diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index 8f8adad9..7effa7fe 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -6,6 +6,7 @@ from dataclasses import dataclass
 from opentelemetry import trace
 from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase, BitsAndBytesConfig
 from typing import Optional, Tuple, List, Type, Dict
+from pathlib import Path
 from peft import PeftModelForCausalLM, get_peft_config, PeftConfig
 
 from text_generation_server.models import Model
@@ -458,6 +459,7 @@
         quantize: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
+        peft_model_path: Optional[Path] = None,
     ):
         if torch.cuda.is_available():
             device = torch.device("cuda")
@@ -484,8 +486,7 @@
                 bnb_4bit_compute_dtype=torch.float16
             )
 
-        has_peft_model = True
-        peft_model_id_or_path = "/mnt/TOFU/HF_MODELS/Llama-2-7b-chat-hf-instruct-pl-lora_adapter_model"
+        has_peft_model = peft_model_path is not None
 
         model = AutoModelForCausalLM.from_pretrained(
             model_id,
@@ -497,8 +498,11 @@
             trust_remote_code=trust_remote_code,
         )
         if has_peft_model:
-            with open(f'{peft_model_id_or_path}/adapter_config.json') as config_file:
+            with open(f'{peft_model_path}/adapter_config.json') as config_file:
                 config = json.load(config_file)
+            # point the base model reference at the local path
+            config["base_model_name_or_path"] = model_id
+            # convert to a PEFT model
             peft_config = get_peft_config(config)
             model = PeftModelForCausalLM(model, peft_config)
             ## Llama does not have a load_adapter method - we need to think about hot swapping here and implement this for Llama
diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py
index 28c671ba..0818a161 100644
--- a/server/text_generation_server/server.py
+++ b/server/text_generation_server/server.py
@@ -124,6 +124,7 @@ def serve(
     dtype: Optional[str],
     trust_remote_code: bool,
     uds_path: Path,
+    peft_model_path: Optional[Path],
 ):
     async def serve_inner(
         model_id: str,
@@ -132,6 +133,7 @@
         quantize: Optional[str] = None,
         dtype: Optional[str] = None,
         trust_remote_code: bool = False,
+        peft_model_path: Optional[Path] = None,
     ):
         unix_socket_template = "unix://{}-{}"
         if sharded:
@@ -146,7 +148,7 @@
         try:
             model = get_model(
-                model_id, revision, sharded, quantize, dtype, trust_remote_code
+                model_id, revision, sharded, quantize, dtype, trust_remote_code, peft_model_path
             )
         except Exception:
             logger.exception("Error when initializing model")
@@ -194,5 +196,5 @@
             await server.stop(0)
 
     asyncio.run(
-        serve_inner(model_id, revision, sharded, quantize, dtype, trust_remote_code)
+        serve_inner(model_id, revision, sharded, quantize, dtype, trust_remote_code, peft_model_path)
     )
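
For reference, the adapter-loading flow this patch adds to CausalLM.__init__ boils down to the minimal standalone sketch below. The model and adapter paths are the local examples from the run-dev target above, quantization and device placement are omitted, and note that building the wrapper from the config alone does not load trained adapter weights (peft's PeftModel.from_pretrained is the call that does both); this mirrors what the patch does rather than prescribing an implementation.

    import json
    from pathlib import Path

    from transformers import AutoModelForCausalLM
    from peft import PeftModelForCausalLM, get_peft_config

    # Example local paths taken from the run-dev target; substitute your own.
    model_id = "/mnt/TOFU/HF_MODELS/Llama-2-7b-chat-hf"
    peft_model_path = Path("/mnt/TOFU/HF_MODELS/Llama-2-7b-chat-hf-function-calling-adapters-v2")

    # Load the base model (the bitsandbytes kwargs used in causal_lm.py are omitted here).
    model = AutoModelForCausalLM.from_pretrained(model_id)

    if peft_model_path is not None:
        # Read the adapter config shipped alongside the PEFT checkpoint.
        with open(peft_model_path / "adapter_config.json") as config_file:
            config = json.load(config_file)
        # Point the config at the local base model so no hub lookup is needed.
        config["base_model_name_or_path"] = model_id
        peft_config = get_peft_config(config)
        # Wrap the base model with the adapter configuration, as the patch does.
        model = PeftModelForCausalLM(model, peft_config)

The run-dev target shows the matching CLI invocation: serve <model> --quantize bitsandbytes --peft-model-path <adapter-dir>.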