mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 11:54:52 +00:00

Cleaner peft code.

commit 9d5a018fac (parent 1569558750)

The commit replaces the hand-rolled PEFT handling in download_and_unload_peft (parse adapter_config.json by hand, dispatch on the config's model_type to pick an Auto class, then wrap the base model in PeftModel) with peft's AutoPeftModelForCausalLM / AutoPeftModelForSeq2SeqLM helpers, which resolve and load the base model themselves. The adapter_config_filename argument becomes unnecessary and is dropped both from the function signature and from the call site in download_weights.
server/text_generation_server/cli.py

@@ -122,7 +122,7 @@ def download_weights(
     if not is_local_model:
         try:
             adapter_config_filename = hf_hub_download(model_id, revision=revision, filename="adapter_config.json")
-            utils.download_and_unload_peft(model_id, revision, adapter_config_filename, trust_remote_code=trust_remote_code)
+            utils.download_and_unload_peft(model_id, revision, trust_remote_code=trust_remote_code)
         except utils.LocalEntryNotFoundError:
             pass
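The call site still probes the Hub for adapter_config.json: hf_hub_download succeeds only for repos that ship a PEFT adapter config, and the swallowed error lets plain checkpoints continue down the normal download path. A minimal sketch of that detection pattern, outside this commit (the helper name is_peft_adapter is illustrative, not part of text-generation-inference):

    # Sketch only: is_peft_adapter is a hypothetical helper, not TGI code.
    from huggingface_hub import hf_hub_download
    from huggingface_hub.utils import EntryNotFoundError, LocalEntryNotFoundError

    def is_peft_adapter(model_id, revision=None):
        """Return True if the repo ships an adapter_config.json, i.e. is a PEFT adapter."""
        try:
            # Fetches (or finds in cache) the adapter config; raises when the
            # repo has no such file, meaning it is a plain model checkpoint.
            hf_hub_download(model_id, revision=revision, filename="adapter_config.json")
            return True
        except (EntryNotFoundError, LocalEntryNotFoundError):
            return False

Dropping adapter_config_filename from the call works because the rewritten helper below lets peft read the adapter config itself.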
server/text_generation_server/utils/peft.py

@@ -1,65 +1,45 @@
 import os
-import sys
 import json
 from loguru import logger
+import torch
 
-from transformers.models.auto import modeling_auto
-from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoConfig, AutoModel, AutoTokenizer
-from peft import PeftModel
+from transformers import AutoTokenizer
+from peft import AutoPeftModelForCausalLM, AutoPeftModelForSeq2SeqLM
 
 
-def download_and_unload_peft(model_id, revision, adapter_config_filename, trust_remote_code):
+def download_and_unload_peft(model_id, revision, trust_remote_code):
+    torch_dtype = torch.float16
+
     logger.info("Peft model detected.")
-    peft_model_id = model_id
-    with open(adapter_config_filename, "r") as f:
-        adapter_config = json.load(f)
-        model_id = adapter_config["base_model_name_or_path"]
-    logger.info(f"Merging the lora weights {repr(peft_model_id)} into the base model {repr(model_id)}")
-    config = AutoConfig.from_pretrained(model_id, revision=revision, trust_remote_code=trust_remote_code)
-    model_type = config.model_type
-    logger.info(f"Starting to load the base model {repr(model_id)}, this may take a while with no feedback.")
-    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
-        base_model = AutoModelForCausalLM.from_pretrained(
+    logger.info("Loading the model it might take a while without feedback")
+    try:
+        model = AutoPeftModelForCausalLM.from_pretrained(
             model_id,
             revision=revision,
-            trust_remote_code=trust_remote_code
+            torch_dtype=torch_dtype,
+            trust_remote_code=trust_remote_code,
+            low_cpu_mem_usage=True,
         )
-    elif model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
-        base_model = AutoModelForSeq2SeqLM(
+    except Exception:
+        model = AutoPeftModelForSeq2SeqLM.from_pretrained(
             model_id,
             revision=revision,
+            torch_dtype=torch_dtype,
+            trust_remote_code=trust_remote_code,
+            low_cpu_mem_usage=True,
         )
-    else:
-        # We have no idea just try either
-        try:
-            base_model = AutoModelForCausalLM.from_pretrained(
-                model_id,
-                revision=revision,
-                trust_remote_code=trust_remote_code
-            )
-        except Exception:
-            base_model = AutoModelForSeq2SeqLM.from_pretrained(
-                model_id,
-                revision=revision,
-                trust_remote_code=trust_remote_code
-            )
     logger.info(f"Loaded.")
+    logger.info(f"Merging the lora weights.")
 
-    logger.info(f"Merging the lora weights {repr(peft_model_id)} into the base model {repr(model_id)}")
+    base_model_id = model.peft_config["default"].base_model_name_or_path
 
-    model = PeftModel.from_pretrained(
-        base_model,
-        peft_model_id,
-        trust_remote_code=trust_remote_code,
-    )
-
     model = model.merge_and_unload()
 
-    os.makedirs(peft_model_id, exist_ok=True)
-    cache_dir = peft_model_id
+    os.makedirs(model_id, exist_ok=True)
+    cache_dir = model_id
     logger.info(f"Saving the newly created merged model to {cache_dir}")
-    tokenizer = AutoTokenizer.from_pretrained(model_id)
+    tokenizer = AutoTokenizer.from_pretrained(base_model_id)
     model.save_pretrained(cache_dir, safe_serialization=True)
-    config.save_pretrained(cache_dir)
+    model.config.save_pretrained(cache_dir)
     tokenizer.save_pretrained(cache_dir)
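The rewritten helper leans on peft's Auto classes: AutoPeftModelForCausalLM / AutoPeftModelForSeq2SeqLM read base_model_name_or_path from the adapter's adapter_config.json, load that base model, and attach the adapter on top, which is why the AutoConfig / model_type dispatch could be deleted. merge_and_unload() then folds the LoRA update into the base weights (for each wrapped linear, W <- W + (lora_alpha / r) * B A) and returns a plain transformers model. A standalone sketch of the same flow under the new code, where ADAPTER_ID and the output directory are placeholders:

    # Sketch of the merge flow this commit introduces; ADAPTER_ID and
    # "merged-model" are placeholder values, not taken from the repo.
    import torch
    from peft import AutoPeftModelForCausalLM
    from transformers import AutoTokenizer

    ADAPTER_ID = "some-user/some-lora-adapter"

    # Reads adapter_config.json, loads the base model it names, and
    # attaches the LoRA adapter on top of it.
    model = AutoPeftModelForCausalLM.from_pretrained(
        ADAPTER_ID,
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
    )
    base_model_id = model.peft_config["default"].base_model_name_or_path

    # Fold the low-rank deltas into the base weights and drop the PEFT
    # wrappers, leaving an ordinary transformers model.
    model = model.merge_and_unload()

    model.save_pretrained("merged-model", safe_serialization=True)
    model.config.save_pretrained("merged-model")
    AutoTokenizer.from_pretrained(base_model_id).save_pretrained("merged-model")

In the commit itself the merged model is written to a directory named after the adapter's model_id (cache_dir = model_id), so the rest of download_weights can pick it up like any other local checkpoint.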