Cleaner peft code.

Nicolas Patry 2023-08-03 12:41:13 +00:00
parent 1569558750
commit 9d5a018fac
2 changed files with 23 additions and 43 deletions


@@ -122,7 +122,7 @@ def download_weights(
     if not is_local_model:
         try:
             adapter_config_filename = hf_hub_download(model_id, revision=revision, filename="adapter_config.json")
-            utils.download_and_unload_peft(model_id, revision, adapter_config_filename, trust_remote_code=trust_remote_code)
+            utils.download_and_unload_peft(model_id, revision, trust_remote_code=trust_remote_code)
         except utils.LocalEntryNotFoundError:
             pass
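
Note: the hunk above detects a PEFT adapter repo by whether hf_hub_download can fetch an adapter_config.json from it; only the filename argument is now passed on. A minimal standalone sketch of that probe (is_peft_repo is an illustrative helper, not part of this commit):

from huggingface_hub import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError, LocalEntryNotFoundError


def is_peft_repo(model_id, revision=None):
    # A repo is treated as a PEFT adapter iff it ships an adapter_config.json.
    try:
        hf_hub_download(model_id, revision=revision, filename="adapter_config.json")
        return True
    except (EntryNotFoundError, LocalEntryNotFoundError):
        return False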


@@ -1,65 +1,45 @@
 import os
-import sys
-import json
 from loguru import logger
 import torch
-from transformers.models.auto import modeling_auto
-from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoConfig, AutoModel, AutoTokenizer
-from peft import PeftModel
+from transformers import AutoTokenizer
+from peft import AutoPeftModelForCausalLM, AutoPeftModelForSeq2SeqLM
 
 
-def download_and_unload_peft(model_id, revision, adapter_config_filename, trust_remote_code):
+def download_and_unload_peft(model_id, revision, trust_remote_code):
     torch_dtype = torch.float16
     logger.info("Peft model detected.")
-    peft_model_id = model_id
-    with open(adapter_config_filename, "r") as f:
-        adapter_config = json.load(f)
-    model_id = adapter_config["base_model_name_or_path"]
-    logger.info(f"Merging the lora weights {repr(peft_model_id)} into the base model {repr(model_id)}")
-    config = AutoConfig.from_pretrained(model_id, revision=revision, trust_remote_code=trust_remote_code)
-    model_type = config.model_type
-    logger.info(f"Starting to load the base model {repr(model_id)}, this may take a while with no feedback.")
-    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
-        base_model = AutoModelForCausalLM.from_pretrained(
-            model_id,
-            revision=revision,
-            trust_remote_code=trust_remote_code
-        )
-    elif model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
-        base_model = AutoModelForSeq2SeqLM(
-            model_id,
-            revision=revision,
-        )
-    else:
-        # We have no idea just try either
-        try:
-            base_model = AutoModelForCausalLM.from_pretrained(
-                model_id,
-                revision=revision,
-                trust_remote_code=trust_remote_code
-            )
-        except Exception:
-            base_model = AutoModelForSeq2SeqLM.from_pretrained(
-                model_id,
-                revision=revision,
-                trust_remote_code=trust_remote_code
-            )
+    logger.info("Loading the model it might take a while without feedback")
+    try:
+        model = AutoPeftModelForCausalLM.from_pretrained(
+            model_id,
+            revision=revision,
+            torch_dtype=torch_dtype,
+            trust_remote_code=trust_remote_code,
+            low_cpu_mem_usage=True,
+        )
+    except Exception:
+        model = AutoPeftModelForSeq2SeqLM.from_pretrained(
+            model_id,
+            revision=revision,
+            torch_dtype=torch_dtype,
+            trust_remote_code=trust_remote_code,
+            low_cpu_mem_usage=True,
+        )
     logger.info(f"Loaded.")
-    logger.info(f"Merging the lora weights.")
-    model = PeftModel.from_pretrained(
-        base_model,
-        peft_model_id,
-        trust_remote_code=trust_remote_code,
-    )
+    base_model_id = model.peft_config["default"].base_model_name_or_path
+    logger.info(f"Merging the lora weights {repr(model_id)} into the base model {repr(base_model_id)}")
     model = model.merge_and_unload()
-    os.makedirs(peft_model_id, exist_ok=True)
-    cache_dir = peft_model_id
+    os.makedirs(model_id, exist_ok=True)
+    cache_dir = model_id
     logger.info(f"Saving the newly created merged model to {cache_dir}")
-    tokenizer = AutoTokenizer.from_pretrained(model_id)
+    tokenizer = AutoTokenizer.from_pretrained(base_model_id)
     model.save_pretrained(cache_dir, safe_serialization=True)
-    config.save_pretrained(cache_dir)
+    model.config.save_pretrained(cache_dir)
     tokenizer.save_pretrained(cache_dir)
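
Note: the rewrite works because peft's AutoPeftModelForCausalLM / AutoPeftModelForSeq2SeqLM read base_model_name_or_path out of the adapter's own adapter_config.json, so the adapter_config_filename parameter and the manual json/AutoConfig plumbing are no longer needed. A rough standalone sketch of the same merge-and-save flow (the adapter id and output path are placeholders, not from this commit):

import torch
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer

adapter_id = "some-user/some-lora-adapter"  # placeholder adapter repo

# Loads the adapter, pulls in the base model named in its adapter_config.json,
# and returns a PeftModel with the LoRA weights attached on top.
model = AutoPeftModelForCausalLM.from_pretrained(
    adapter_id,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)

# Resolve the base model id before unloading, then fold the LoRA deltas into
# the base weights; the result is a plain transformers model that no longer
# needs peft at inference time.
base_model_id = model.peft_config["default"].base_model_name_or_path
merged = model.merge_and_unload()

tokenizer = AutoTokenizer.from_pretrained(base_model_id)
merged.save_pretrained("./merged-model", safe_serialization=True)
tokenizer.save_pretrained("./merged-model")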