mirror of https://github.com/huggingface/text-generation-inference.git

commit 1659b871b6: load peft from cli
parent commit: aba56c1343

@@ -24,6 +24,6 @@ install: gen-server install-torch
 	pip install -e ".[bnb, accelerate]"
 
 run-dev:
-	SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=1 text_generation_server/cli.py serve /mnt/TOFU/HF_MODELS/Llama-2-7b-chat-hf --quantize bitsandbytes
+	SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=1 text_generation_server/cli.py serve /mnt/TOFU/HF_MODELS/Llama-2-7b-chat-hf --quantize bitsandbytes --peft-model-path /mnt/TOFU/HF_MODELS/Llama-2-7b-chat-hf-function-calling-adapters-v2
 
 export-requirements:
 	poetry export -o requirements.txt -E bnb -E quantize --without-hashes

@@ -33,6 +33,7 @@ def serve(
     dtype: Optional[Dtype] = None,
     trust_remote_code: bool = False,
     uds_path: Path = "/tmp/text-generation-server",
+    peft_model_path: Optional[Path] = None,
     logger_level: str = "INFO",
     json_output: bool = False,
     otlp_endpoint: Optional[str] = None,

@@ -79,7 +80,7 @@ def serve(
             "Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."
         )
     server.serve(
-        model_id, revision, sharded, quantize, dtype, trust_remote_code, uds_path
+        model_id, revision, sharded, quantize, dtype, trust_remote_code, uds_path, peft_model_path
     )

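The two hunks above touch the typer-based CLI entry point (the `def serve(` context suggests the server's `cli.py`). Typer derives option names from parameter names, so the new `peft_model_path: Optional[Path] = None` keyword is what makes the `--peft-model-path` flag used in the Makefile's `run-dev` target work. Below is a minimal, standalone sketch of that mapping; it is not the repository's actual file, and the model and adapter names are placeholders.

```python
# Standalone sketch: how a typer parameter becomes a --peft-model-path CLI option.
# Not the repository's cli.py; names and paths below are placeholders.
from pathlib import Path
from typing import Optional

import typer

app = typer.Typer()


@app.command()
def serve(
    model_id: str,
    peft_model_path: Optional[Path] = None,  # exposed on the command line as --peft-model-path
):
    if peft_model_path is not None:
        typer.echo(f"would load PEFT adapter from {peft_model_path}")
    typer.echo(f"would serve {model_id}")


if __name__ == "__main__":
    app()
    # usage: python sketch.py some-model --peft-model-path /path/to/adapter
```
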
@@ -5,6 +5,7 @@ from loguru import logger
 from transformers.configuration_utils import PretrainedConfig
 from transformers.models.auto import modeling_auto
 from typing import Optional
+from pathlib import Path
 
 from text_generation_server.models.model import Model
 from text_generation_server.models.causal_lm import CausalLM

@@ -75,6 +76,7 @@ def get_model(
     quantize: Optional[str],
     dtype: Optional[str],
     trust_remote_code: bool,
+    peft_model_path: Optional[Path],
 ) -> Model:
     if dtype is None:
         dtype = torch.float16

@@ -180,6 +182,7 @@ def get_model(
             quantize=quantize,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
+            peft_model_path=peft_model_path,
         )
 
     elif model_type == "llama":

@@ -200,6 +203,7 @@ def get_model(
             quantize=quantize,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
+            peft_model_path=peft_model_path,
         )
 
     if model_type in ["RefinedWeb", "RefinedWebModel", "falcon"]:

@@ -298,6 +302,7 @@ def get_model(
             quantize=quantize,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
+            peft_model_path=peft_model_path,
         )
     if "AutoModelForSeq2SeqLM" in auto_map.keys():
         return Seq2SeqLM(

@@ -6,6 +6,7 @@ from dataclasses import dataclass
 from opentelemetry import trace
 from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase, BitsAndBytesConfig
 from typing import Optional, Tuple, List, Type, Dict
+from pathlib import Path
 from peft import PeftModelForCausalLM, get_peft_config, PeftConfig
 
 from text_generation_server.models import Model

@@ -458,6 +459,7 @@ class CausalLM(Model):
         quantize: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
+        peft_model_path: Optional[Path] = None,
     ):
         if torch.cuda.is_available():
             device = torch.device("cuda")

@@ -484,8 +486,7 @@ class CausalLM(Model):
                 bnb_4bit_compute_dtype=torch.float16
             )
 
-        has_peft_model = True
-        peft_model_id_or_path = "/mnt/TOFU/HF_MODELS/Llama-2-7b-chat-hf-instruct-pl-lora_adapter_model"
+        has_peft_model = peft_model_path is not None
 
         model = AutoModelForCausalLM.from_pretrained(
             model_id,

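For context, the surrounding lines are the bitsandbytes 4-bit load path, and the parent commit's hardcoded adapter path is replaced by a simple check on the CLI-supplied argument. A hedged sketch of that pattern in isolation follows; `load_in_4bit=True`, `device_map="auto"`, and the placeholder model id are assumptions for the sketch, not taken from the diff.

```python
# Sketch of the quantized-load-plus-optional-adapter check mirrored from the hunk above.
# load_in_4bit and device_map are assumptions; only bnb_4bit_compute_dtype appears in the diff.
from pathlib import Path
from typing import Optional

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig


def load_quantized_base(model_id: str, peft_model_path: Optional[Path] = None):
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,  # same compute dtype as in the diff
    )
    has_peft_model = peft_model_path is not None  # mirrors the new check
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        quantization_config=quantization_config,
        device_map="auto",
    )
    return model, has_peft_model
```
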
@@ -497,8 +498,11 @@ class CausalLM(Model):
             trust_remote_code=trust_remote_code,
         )
         if has_peft_model:
-            with open(f'{peft_model_id_or_path}/adapter_config.json') as config_file:
+            with open(f'{peft_model_path}/adapter_config.json') as config_file:
                 config = json.load(config_file)
+            # patch to a local path
+            config["base_model_name_or_path"] = model_id
             # convert to peft model
             peft_config = get_peft_config(config)
             model = PeftModelForCausalLM(model, peft_config)
+            ## Llama does not have a load_adapter method - we need to think about hot swapping here and implement this for Llama

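The new branch reads the adapter's `adapter_config.json`, rewrites `base_model_name_or_path` to point at the local base model, and wraps the base model in `PeftModelForCausalLM`. A self-contained sketch of the same flow is below, with placeholder paths. One thing worth noting when reviewing: constructing `PeftModelForCausalLM` from a config attaches the adapter layers, but it does not by itself load the trained adapter weights; peft's `PeftModel.from_pretrained` is the usual shortcut that handles both.

```python
# Sketch of the adapter-wrapping flow from the hunk above; paths are placeholders.
import json
from pathlib import Path

from peft import PeftModelForCausalLM, get_peft_config
from transformers import AutoModelForCausalLM

base_model_dir = Path("/path/to/base-model")   # placeholder
adapter_dir = Path("/path/to/lora-adapter")    # placeholder

model = AutoModelForCausalLM.from_pretrained(base_model_dir)

with open(adapter_dir / "adapter_config.json") as config_file:
    config = json.load(config_file)

# Point the adapter config at the local base model instead of a Hub id.
config["base_model_name_or_path"] = str(base_model_dir)

peft_config = get_peft_config(config)
model = PeftModelForCausalLM(model, peft_config)
# Note: the trained adapter weights still need to be loaded separately;
# peft.PeftModel.from_pretrained(base_model, adapter_dir) does config + weights in one step.
```
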
@@ -124,6 +124,7 @@ def serve(
     dtype: Optional[str],
     trust_remote_code: bool,
     uds_path: Path,
+    peft_model_path: Optional[Path],
 ):
     async def serve_inner(
         model_id: str,

@@ -132,6 +133,7 @@ def serve(
         quantize: Optional[str] = None,
         dtype: Optional[str] = None,
         trust_remote_code: bool = False,
+        peft_model_path: Optional[Path] = None,
     ):
         unix_socket_template = "unix://{}-{}"
         if sharded:

@@ -146,7 +148,7 @@ def serve(
 
         try:
             model = get_model(
-                model_id, revision, sharded, quantize, dtype, trust_remote_code
+                model_id, revision, sharded, quantize, dtype, trust_remote_code, peft_model_path
             )
         except Exception:
             logger.exception("Error when initializing model")

@@ -194,5 +196,5 @@ def serve(
         await server.stop(0)
 
     asyncio.run(
-        serve_inner(model_id, revision, sharded, quantize, dtype, trust_remote_code)
+        serve_inner(model_id, revision, sharded, quantize, dtype, trust_remote_code, peft_model_path)
     )

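End to end, the commit threads the new argument from the CLI through `server.serve` and `serve_inner` into `get_model` and finally the model class. A hedged sketch of that call chain with stub functions follows; only the new parameter is shown, all other arguments from the real signatures are omitted for brevity, and the bodies are illustrative stubs rather than the project's implementations.

```python
# Stub sketch of how peft_model_path flows through the layers touched by this commit.
from pathlib import Path
from typing import Optional


def get_model(model_id: str, peft_model_path: Optional[Path]) -> str:
    # models/__init__.py: hands the path on to the model class, e.g. CausalLM.
    return f"CausalLM(model_id={model_id!r}, peft_model_path={peft_model_path})"


def serve_inner(model_id: str, peft_model_path: Optional[Path] = None) -> None:
    # server.py: receives the value and passes it to get_model().
    print(get_model(model_id, peft_model_path))


def serve(model_id: str, peft_model_path: Optional[Path] = None) -> None:
    # cli.py: the --peft-model-path option lands here and is forwarded unchanged.
    serve_inner(model_id, peft_model_path)


if __name__ == "__main__":
    serve("Llama-2-7b-chat-hf", Path("/path/to/lora-adapter"))  # placeholder adapter path
```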