load peft from cli

Chris 2023-08-27 20:56:53 +02:00
parent aba56c1343
commit 1659b871b6
5 changed files with 19 additions and 7 deletions

Makefile

@@ -24,6 +24,6 @@ install: gen-server install-torch
 	pip install -e ".[bnb, accelerate]"
 run-dev:
-	SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=1 text_generation_server/cli.py serve /mnt/TOFU/HF_MODELS/Llama-2-7b-chat-hf --quantize bitsandbytes
+	SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=1 text_generation_server/cli.py serve /mnt/TOFU/HF_MODELS/Llama-2-7b-chat-hf --quantize bitsandbytes --peft-model-path /mnt/TOFU/HF_MODELS/Llama-2-7b-chat-hf-function-calling-adapters-v2
 export-requirements:
 	poetry export -o requirements.txt -E bnb -E quantize --without-hashes

text_generation_server/cli.py

@@ -33,6 +33,7 @@ def serve(
     dtype: Optional[Dtype] = None,
     trust_remote_code: bool = False,
     uds_path: Path = "/tmp/text-generation-server",
+    peft_model_path: Optional[Path] = None,
     logger_level: str = "INFO",
     json_output: bool = False,
     otlp_endpoint: Optional[str] = None,
@@ -79,7 +80,7 @@ def serve(
             "Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."
         )
     server.serve(
-        model_id, revision, sharded, quantize, dtype, trust_remote_code, uds_path
+        model_id, revision, sharded, quantize, dtype, trust_remote_code, uds_path, peft_model_path
     )
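
Assuming the CLI is typer-based, as in upstream text-generation-inference, the new keyword argument in `serve` is all that is needed to surface a `--peft-model-path` flag, since typer derives options from the function signature. A minimal, hypothetical sketch of that mapping, not taken from this repository:

# cli_sketch.py, a stand-alone illustration (hypothetical file, not from this repo)
from pathlib import Path
from typing import Optional

import typer

app = typer.Typer()


@app.command()
def serve(
    model_id: str,
    quantize: Optional[str] = None,
    # An Optional[Path] defaulting to None becomes an optional
    # --peft-model-path flag; typer swaps underscores for dashes.
    peft_model_path: Optional[Path] = None,
):
    if peft_model_path is not None:
        typer.echo(f"loading PEFT adapter from {peft_model_path}")
    typer.echo(f"serving {model_id} (quantize={quantize})")


if __name__ == "__main__":
    # e.g. python cli_sketch.py /path/to/base-model --peft-model-path /path/to/adapter
    app()

When the option is omitted, `peft_model_path` stays `None` and no adapter is attached.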

text_generation_server/models/__init__.py

@@ -5,6 +5,7 @@ from loguru import logger
 from transformers.configuration_utils import PretrainedConfig
 from transformers.models.auto import modeling_auto
 from typing import Optional
+from pathlib import Path
 from text_generation_server.models.model import Model
 from text_generation_server.models.causal_lm import CausalLM
@@ -75,6 +76,7 @@ def get_model(
     quantize: Optional[str],
     dtype: Optional[str],
     trust_remote_code: bool,
+    peft_model_path: Optional[Path],
 ) -> Model:
     if dtype is None:
         dtype = torch.float16
@@ -180,6 +182,7 @@ def get_model(
             quantize=quantize,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
+            peft_model_path=peft_model_path,
         )
     elif model_type == "llama":
@@ -200,6 +203,7 @@ def get_model(
             quantize=quantize,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
+            peft_model_path=peft_model_path,
         )
     if model_type in ["RefinedWeb", "RefinedWebModel", "falcon"]:
@@ -298,6 +302,7 @@ def get_model(
                 quantize=quantize,
                 dtype=dtype,
                 trust_remote_code=trust_remote_code,
+                peft_model_path=peft_model_path,
             )
         if "AutoModelForSeq2SeqLM" in auto_map.keys():
             return Seq2SeqLM(

text_generation_server/models/causal_lm.py

@@ -6,6 +6,7 @@ from dataclasses import dataclass
 from opentelemetry import trace
 from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase, BitsAndBytesConfig
 from typing import Optional, Tuple, List, Type, Dict
+from pathlib import Path
 from peft import PeftModelForCausalLM, get_peft_config, PeftConfig
 from text_generation_server.models import Model
@@ -458,6 +459,7 @@ class CausalLM(Model):
         quantize: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
+        peft_model_path: Optional[Path] = None,
     ):
         if torch.cuda.is_available():
             device = torch.device("cuda")
@@ -484,8 +486,7 @@ class CausalLM(Model):
             bnb_4bit_compute_dtype=torch.float16
         )
-        has_peft_model = True
-        peft_model_id_or_path = "/mnt/TOFU/HF_MODELS/Llama-2-7b-chat-hf-instruct-pl-lora_adapter_model"
+        has_peft_model = peft_model_path is not None
         model = AutoModelForCausalLM.from_pretrained(
             model_id,
@@ -497,8 +498,11 @@ class CausalLM(Model):
             trust_remote_code=trust_remote_code,
         )
         if has_peft_model:
-            with open(f'{peft_model_id_or_path}/adapter_config.json') as config_file:
+            with open(f'{peft_model_path}/adapter_config.json') as config_file:
                 config = json.load(config_file)
+            # patch to a local path
+            config["base_model_name_or_path"] = model_id
+            # convert to peft model
             peft_config = get_peft_config(config)
             model = PeftModelForCausalLM(model, peft_config)
             ## Llama does not have a load_adapter method - we need to think about hot swapping here and implement this for Llama
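
Taken out of the surrounding diff, the adapter-attach step above boils down to the following. This is a minimal sketch assuming an unquantized base model; `load_base_with_adapter` is a hypothetical helper name and the BitsAndBytes arguments from the commit are omitted:

import json
from pathlib import Path

import torch
from peft import PeftModelForCausalLM, get_peft_config
from transformers import AutoModelForCausalLM


def load_base_with_adapter(model_id: str, peft_model_path: Path):
    # Load the base model (quantization config left out for brevity).
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16)
    # Read the adapter config that ships alongside the LoRA weights ...
    with open(peft_model_path / "adapter_config.json") as config_file:
        config = json.load(config_file)
    # ... point it at the locally resolved base model, as the commit does ...
    config["base_model_name_or_path"] = model_id
    peft_config = get_peft_config(config)
    # ... and wrap the base model with the adapter structure the config describes.
    return PeftModelForCausalLM(model, peft_config)

Note that building `PeftModelForCausalLM` from a config only creates the adapter modules; when the trained adapter weights should also be loaded from the same directory, `PeftModel.from_pretrained(model, peft_model_path)` is the usual route.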

text_generation_server/server.py

@@ -124,6 +124,7 @@ def serve(
     dtype: Optional[str],
     trust_remote_code: bool,
     uds_path: Path,
+    peft_model_path: Optional[Path]
 ):
     async def serve_inner(
         model_id: str,
@@ -132,6 +133,7 @@ def serve(
         quantize: Optional[str] = None,
         dtype: Optional[str] = None,
         trust_remote_code: bool = False,
+        peft_model_path: Optional[Path] = None,
     ):
         unix_socket_template = "unix://{}-{}"
         if sharded:
@@ -146,7 +148,7 @@ def serve(
         try:
             model = get_model(
-                model_id, revision, sharded, quantize, dtype, trust_remote_code
+                model_id, revision, sharded, quantize, dtype, trust_remote_code, peft_model_path
             )
         except Exception:
             logger.exception("Error when initializing model")
@@ -194,5 +196,5 @@ def serve(
             await server.stop(0)
     asyncio.run(
-        serve_inner(model_id, revision, sharded, quantize, dtype, trust_remote_code)
+        serve_inner(model_id, revision, sharded, quantize, dtype, trust_remote_code, peft_model_path)
     )
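
On the hot-swapping question left open in the causal_lm.py comment above: the peft wrapper itself provides `load_adapter` and `set_adapter`, so one possible direction is to register several adapters on a single wrapped base model and switch between them per request. A hedged sketch with hypothetical helper and adapter names, not part of this commit:

from pathlib import Path
from typing import Dict

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM


def build_swappable_model(model_id: str, default_name: str, default_path: Path,
                          extra_adapters: Dict[str, Path]) -> PeftModel:
    base = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16)
    # from_pretrained loads both the adapter config and its trained weights.
    model = PeftModel.from_pretrained(base, str(default_path), adapter_name=default_name)
    for name, path in extra_adapters.items():
        # Additional adapters are registered on the same wrapped model ...
        model.load_adapter(str(path), adapter_name=name)
    return model


# ... and can be activated per request without reloading the base model, e.g.:
# model = build_swappable_model("base-id", "chat", Path("adapters/chat"),
#                               {"function-calling": Path("adapters/fc")})
# model.set_adapter("function-calling")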