Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-10 20:04:52 +00:00)
Commit 1659b871b6: load peft from cli
Parent: aba56c1343
server/Makefile
@@ -24,6 +24,6 @@ install: gen-server install-torch
 	pip install -e ".[bnb, accelerate]"
 
 run-dev:
-	SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=1 text_generation_server/cli.py serve /mnt/TOFU/HF_MODELS/Llama-2-7b-chat-hf --quantize bitsandbytes
+	SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=1 text_generation_server/cli.py serve /mnt/TOFU/HF_MODELS/Llama-2-7b-chat-hf --quantize bitsandbytes --peft-model-path /mnt/TOFU/HF_MODELS/Llama-2-7b-chat-hf-function-calling-adapters-v2
 export-requirements:
 	poetry export -o requirements.txt -E bnb -E quantize --without-hashes
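The new --peft-model-path flag points at a local PEFT adapter directory; the only file the server-side change below reads explicitly is its adapter_config.json. A quick sanity-check sketch (path is a placeholder, not from this commit):

# Sanity check: the directory handed to --peft-model-path should be a PEFT
# adapter, i.e. contain at least adapter_config.json plus the adapter weights.
from pathlib import Path

adapter_dir = Path("/path/to/lora-adapter")  # placeholder
assert (adapter_dir / "adapter_config.json").is_file(), "not a PEFT adapter directory"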
server/text_generation_server/cli.py
@@ -33,6 +33,7 @@ def serve(
     dtype: Optional[Dtype] = None,
     trust_remote_code: bool = False,
     uds_path: Path = "/tmp/text-generation-server",
+    peft_model_path: Optional[Path] = None,
     logger_level: str = "INFO",
     json_output: bool = False,
     otlp_endpoint: Optional[str] = None,
@@ -79,7 +80,7 @@ def serve(
             "Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."
         )
     server.serve(
-        model_id, revision, sharded, quantize, dtype, trust_remote_code, uds_path
+        model_id, revision, sharded, quantize, dtype, trust_remote_code, uds_path, peft_model_path
     )
 
 
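Because the CLI is built on typer, the new keyword argument surfaces automatically as a --peft-model-path option. A minimal sketch of that mechanism, trimmed to the relevant parameters (this is not the real cli.py):

# Sketch: a typer command with an Optional[Path] parameter; typer derives the
# --peft-model-path option name and parses it into a pathlib.Path.
from pathlib import Path
from typing import Optional

import typer

app = typer.Typer()

@app.command()
def serve(
    model_id: str,
    quantize: Optional[str] = None,
    peft_model_path: Optional[Path] = None,  # exposed as --peft-model-path
):
    typer.echo(f"peft_model_path={peft_model_path}")

if __name__ == "__main__":
    app()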
server/text_generation_server/models/__init__.py
@@ -5,6 +5,7 @@ from loguru import logger
 from transformers.configuration_utils import PretrainedConfig
 from transformers.models.auto import modeling_auto
 from typing import Optional
+from pathlib import Path
 
 from text_generation_server.models.model import Model
 from text_generation_server.models.causal_lm import CausalLM
@@ -75,6 +76,7 @@ def get_model(
     quantize: Optional[str],
     dtype: Optional[str],
     trust_remote_code: bool,
+    peft_model_path: Optional[Path],
 ) -> Model:
     if dtype is None:
         dtype = torch.float16
@@ -180,6 +182,7 @@ def get_model(
             quantize=quantize,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
+            peft_model_path=peft_model_path,
         )
 
     elif model_type == "llama":
@@ -200,6 +203,7 @@ def get_model(
             quantize=quantize,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
+            peft_model_path=peft_model_path,
         )
 
     if model_type in ["RefinedWeb", "RefinedWebModel", "falcon"]:
@@ -298,6 +302,7 @@ def get_model(
                 quantize=quantize,
                 dtype=dtype,
                 trust_remote_code=trust_remote_code,
+                peft_model_path=peft_model_path,
             )
         if "AutoModelForSeq2SeqLM" in auto_map.keys():
             return Seq2SeqLM(
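get_model only forwards peft_model_path to whichever constructor it picks; the actual consumer is CausalLM below, which parses the adapter's adapter_config.json. For orientation, a representative LoRA config expressed as the dict that get_peft_config expects (field values are illustrative, not taken from this commit):

# Illustrative LoRA adapter config, as the dict get_peft_config() consumes.
# A real adapter ships this content as adapter_config.json.
from peft import get_peft_config

config = {
    "peft_type": "LORA",
    "task_type": "CAUSAL_LM",
    "base_model_name_or_path": "meta-llama/Llama-2-7b-chat-hf",
    "r": 16,
    "lora_alpha": 32,
    "lora_dropout": 0.05,
    "target_modules": ["q_proj", "v_proj"],
}
peft_config = get_peft_config(config)  # -> LoraConfig(...)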
server/text_generation_server/models/causal_lm.py
@@ -6,6 +6,7 @@ from dataclasses import dataclass
 from opentelemetry import trace
 from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase, BitsAndBytesConfig
 from typing import Optional, Tuple, List, Type, Dict
+from pathlib import Path
 from peft import PeftModelForCausalLM, get_peft_config, PeftConfig
 
 from text_generation_server.models import Model
@@ -458,6 +459,7 @@ class CausalLM(Model):
         quantize: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
+        peft_model_path: Optional[Path] = None,
     ):
         if torch.cuda.is_available():
             device = torch.device("cuda")
@@ -484,8 +486,7 @@ class CausalLM(Model):
             bnb_4bit_compute_dtype=torch.float16
         )
 
-        has_peft_model = True
-        peft_model_id_or_path = "/mnt/TOFU/HF_MODELS/Llama-2-7b-chat-hf-instruct-pl-lora_adapter_model"
+        has_peft_model = peft_model_path is not None
 
         model = AutoModelForCausalLM.from_pretrained(
             model_id,
@@ -497,8 +498,11 @@ class CausalLM(Model):
             trust_remote_code=trust_remote_code,
         )
         if has_peft_model:
-            with open(f'{peft_model_id_or_path}/adapter_config.json') as config_file:
+            with open(f'{peft_model_path}/adapter_config.json') as config_file:
                 config = json.load(config_file)
+            # patch to a local path
+            config["base_model_name_or_path"] = model_id
+            # convert to a PEFT model
             peft_config = get_peft_config(config)
             model = PeftModelForCausalLM(model, peft_config)
             ## Llama does not have a load_adapter method - we need to think about hot swapping here and implement this for Llama
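For reference, the adapter-loading flow above, pulled out as a standalone sketch (paths are placeholders, not taken from this commit):

# Minimal sketch of the flow added to CausalLM.__init__: read the adapter's
# config, point it at the local base model, and wrap the model with PEFT.
import json
from pathlib import Path

from peft import PeftModelForCausalLM, get_peft_config
from transformers import AutoModelForCausalLM

model_id = "/path/to/Llama-2-7b-chat-hf"         # local base model (placeholder)
peft_model_path = Path("/path/to/lora-adapter")  # adapter dir with adapter_config.json

model = AutoModelForCausalLM.from_pretrained(model_id)

with open(peft_model_path / "adapter_config.json") as config_file:
    config = json.load(config_file)

# Point the adapter config at the local base model instead of a Hub id.
config["base_model_name_or_path"] = model_id

peft_config = get_peft_config(config)
model = PeftModelForCausalLM(model, peft_config)

Note that constructing PeftModelForCausalLM from the config only injects freshly initialised adapter modules; also loading the trained adapter weights would normally go through PeftModel.from_pretrained(base_model, peft_model_path), which ties into the load_adapter / hot-swapping TODO in the last context line above.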
server/text_generation_server/server.py
@@ -124,6 +124,7 @@ def serve(
     dtype: Optional[str],
     trust_remote_code: bool,
     uds_path: Path,
+    peft_model_path: Optional[Path]
 ):
     async def serve_inner(
         model_id: str,
@@ -132,6 +133,7 @@ def serve(
         quantize: Optional[str] = None,
         dtype: Optional[str] = None,
         trust_remote_code: bool = False,
+        peft_model_path: Optional[Path] = None,
     ):
         unix_socket_template = "unix://{}-{}"
         if sharded:
@@ -146,7 +148,7 @@ def serve(
 
         try:
             model = get_model(
-                model_id, revision, sharded, quantize, dtype, trust_remote_code
+                model_id, revision, sharded, quantize, dtype, trust_remote_code, peft_model_path
            )
         except Exception:
             logger.exception("Error when initializing model")
@@ -194,5 +196,5 @@ def serve(
         await server.stop(0)
 
     asyncio.run(
-        serve_inner(model_id, revision, sharded, quantize, dtype, trust_remote_code)
+        serve_inner(model_id, revision, sharded, quantize, dtype, trust_remote_code, peft_model_path)
     )