load peft from cli

Chris 2023-08-27 20:56:53 +02:00
parent aba56c1343
commit 1659b871b6
5 changed files with 19 additions and 7 deletions

View File

@@ -24,6 +24,6 @@ install: gen-server install-torch
pip install -e ".[bnb, accelerate]"
run-dev:
SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=1 text_generation_server/cli.py serve /mnt/TOFU/HF_MODELS/Llama-2-7b-chat-hf --quantize bitsandbytes
SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=1 text_generation_server/cli.py serve /mnt/TOFU/HF_MODELS/Llama-2-7b-chat-hf --quantize bitsandbytes --peft-model-path /mnt/TOFU/HF_MODELS/Llama-2-7b-chat-hf-function-calling-adapters-v2
export-requirements:
poetry export -o requirements.txt -E bnb -E quantize --without-hashes

View File

@@ -33,6 +33,7 @@ def serve(
dtype: Optional[Dtype] = None,
trust_remote_code: bool = False,
uds_path: Path = "/tmp/text-generation-server",
peft_model_path: Optional[Path] = None,
logger_level: str = "INFO",
json_output: bool = False,
otlp_endpoint: Optional[str] = None,
@@ -79,7 +80,7 @@ def serve(
"Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."
)
server.serve(
model_id, revision, sharded, quantize, dtype, trust_remote_code, uds_path
model_id, revision, sharded, quantize, dtype, trust_remote_code, uds_path, peft_model_path
)
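cli.py here appears to be a typer app (as in upstream text-generation-inference), so the keyword parameter added above is all that is needed for a new --peft-model-path option: typer derives option names from parameter names by swapping underscores for dashes. A minimal standalone sketch of that mechanism (illustrative only, not the repo's actual cli.py):

    import typer
    from pathlib import Path
    from typing import Optional

    app = typer.Typer()

    @app.command()
    def serve(
        model_id: str,
        peft_model_path: Optional[Path] = None,  # surfaces on the CLI as --peft-model-path
    ):
        # typer parses the flag into a pathlib.Path, or leaves None when it is omitted
        typer.echo(f"model={model_id} adapter={peft_model_path}")

    if __name__ == "__main__":
        app()

Invoked as in the Makefile change above (serve <model dir> --quantize bitsandbytes --peft-model-path <adapter dir>), the adapter path then only has to be threaded through server.serve and get_model down to CausalLM, which is what the remaining hunks do.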

View File

@@ -5,6 +5,7 @@ from loguru import logger
from transformers.configuration_utils import PretrainedConfig
from transformers.models.auto import modeling_auto
from typing import Optional
from pathlib import Path
from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM
@@ -75,6 +76,7 @@ def get_model(
quantize: Optional[str],
dtype: Optional[str],
trust_remote_code: bool,
peft_model_path: Optional[Path],
) -> Model:
if dtype is None:
dtype = torch.float16
@@ -180,6 +182,7 @@ def get_model(
quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code,
peft_model_path=peft_model_path,
)
elif model_type == "llama":
@@ -200,6 +203,7 @@ def get_model(
quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code,
peft_model_path=peft_model_path,
)
if model_type in ["RefinedWeb", "RefinedWebModel", "falcon"]:
@@ -298,6 +302,7 @@ def get_model(
quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code,
peft_model_path=peft_model_path,
)
if "AutoModelForSeq2SeqLM" in auto_map.keys():
return Seq2SeqLM(

View File

@@ -6,6 +6,7 @@ from dataclasses import dataclass
from opentelemetry import trace
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase, BitsAndBytesConfig
from typing import Optional, Tuple, List, Type, Dict
from pathlib import Path
from peft import PeftModelForCausalLM, get_peft_config, PeftConfig
from text_generation_server.models import Model
@@ -458,6 +459,7 @@ class CausalLM(Model):
quantize: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
peft_model_path: Optional[Path] = None,
):
if torch.cuda.is_available():
device = torch.device("cuda")
@@ -484,8 +486,7 @@ class CausalLM(Model):
bnb_4bit_compute_dtype=torch.float16
)
has_peft_model = True
peft_model_id_or_path = "/mnt/TOFU/HF_MODELS/Llama-2-7b-chat-hf-instruct-pl-lora_adapter_model"
has_peft_model = peft_model_path is not None
model = AutoModelForCausalLM.from_pretrained(
model_id,
@@ -497,8 +498,11 @@ class CausalLM(Model):
trust_remote_code=trust_remote_code,
)
if has_peft_model:
with open(f'{peft_model_id_or_path}/adapter_config.json') as config_file:
with open(f'{peft_model_path}/adapter_config.json') as config_file:
config = json.load(config_file)
# patch the config to point at the local base model path
config["base_model_name_or_path"] = model_id
# convert to a PEFT model
peft_config = get_peft_config(config)
model = PeftModelForCausalLM(model, peft_config)
## Llama does not have a load_adapter method - hot swapping adapters would need to be designed and implemented separately for Llama
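The block above replaces the previously hardcoded adapter directory with the new peft_model_path argument. A condensed, standalone sketch of the same wrapping logic, with placeholder paths and a fixed dtype standing in for the surrounding CausalLM setup:

    import json
    import torch
    from transformers import AutoModelForCausalLM
    from peft import PeftModelForCausalLM, get_peft_config

    base_model_id = "/path/to/base-model"      # placeholder
    peft_model_path = "/path/to/lora-adapter"  # placeholder

    # load the base model as usual
    model = AutoModelForCausalLM.from_pretrained(base_model_id, torch_dtype=torch.float16)

    # read the adapter config that ships with the PEFT checkpoint
    with open(f"{peft_model_path}/adapter_config.json") as config_file:
        config = json.load(config_file)

    # point the adapter config at the local base model instead of a hub id
    config["base_model_name_or_path"] = base_model_id

    # build a PeftConfig from the dict and wrap the base model with the adapter layers
    peft_config = get_peft_config(config)
    model = PeftModelForCausalLM(model, peft_config)

Note that constructing PeftModelForCausalLM from a config only attaches freshly initialised adapter layers; loading the trained adapter weights from peft_model_path would normally go through PeftModel.from_pretrained or load_adapter, which relates to the load_adapter limitation the comment above flags for Llama.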

View File

@@ -124,6 +124,7 @@ def serve(
dtype: Optional[str],
trust_remote_code: bool,
uds_path: Path,
peft_model_path: Optional[Path],
):
async def serve_inner(
model_id: str,
@@ -132,6 +133,7 @@ def serve(
quantize: Optional[str] = None,
dtype: Optional[str] = None,
trust_remote_code: bool = False,
peft_model_path: Optional[Path] = None,
):
unix_socket_template = "unix://{}-{}"
if sharded:
@@ -146,7 +148,7 @@ def serve(
try:
model = get_model(
model_id, revision, sharded, quantize, dtype, trust_remote_code
model_id, revision, sharded, quantize, dtype, trust_remote_code, peft_model_path
)
except Exception:
logger.exception("Error when initializing model")
@@ -194,5 +196,5 @@ def serve(
await server.stop(0)
asyncio.run(
serve_inner(model_id, revision, sharded, quantize, dtype, trust_remote_code)
serve_inner(model_id, revision, sharded, quantize, dtype, trust_remote_code, peft_model_path)
)