Cleaner peft code.

This commit is contained in:
Nicolas Patry 2023-08-03 12:41:13 +00:00
parent 1569558750
commit 9d5a018fac
2 changed files with 23 additions and 43 deletions

View File

@ -122,7 +122,7 @@ def download_weights(
if not is_local_model: if not is_local_model:
try: try:
adapter_config_filename = hf_hub_download(model_id, revision=revision, filename="adapter_config.json") adapter_config_filename = hf_hub_download(model_id, revision=revision, filename="adapter_config.json")
utils.download_and_unload_peft(model_id, revision, adapter_config_filename, trust_remote_code=trust_remote_code) utils.download_and_unload_peft(model_id, revision, trust_remote_code=trust_remote_code)
except utils.LocalEntryNotFoundError: except utils.LocalEntryNotFoundError:
pass pass

View File

@ -1,65 +1,45 @@
import os import os
import sys
import json import json
from loguru import logger from loguru import logger
import torch
from transformers.models.auto import modeling_auto from transformers import AutoTokenizer
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoConfig, AutoModel, AutoTokenizer from peft import AutoPeftModelForCausalLM, AutoPeftModelForSeq2SeqLM
from peft import PeftModel
def download_and_unload_peft(model_id, revision, adapter_config_filename, trust_remote_code): def download_and_unload_peft(model_id, revision, trust_remote_code):
torch_dtype = torch.float16
logger.info("Peft model detected.") logger.info("Peft model detected.")
peft_model_id = model_id logger.info("Loading the model it might take a while without feedback")
with open(adapter_config_filename, "r") as f:
adapter_config = json.load(f)
model_id = adapter_config["base_model_name_or_path"]
logger.info(f"Merging the lora weights {repr(peft_model_id)} into the base model {repr(model_id)}")
config = AutoConfig.from_pretrained(model_id, revision=revision, trust_remote_code=trust_remote_code)
model_type = config.model_type
logger.info(f"Starting to load the base model {repr(model_id)}, this may take a while with no feedback.")
if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
base_model = AutoModelForCausalLM.from_pretrained(
model_id,
revision=revision,
trust_remote_code=trust_remote_code
)
elif model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
base_model = AutoModelForSeq2SeqLM(
model_id,
revision=revision,
)
else:
# We have no idea just try either
try: try:
base_model = AutoModelForCausalLM.from_pretrained( model = AutoPeftModelForCausalLM.from_pretrained(
model_id, model_id,
revision=revision, revision=revision,
trust_remote_code=trust_remote_code torch_dtype=torch_dtype,
trust_remote_code=trust_remote_code,
low_cpu_mem_usage=True,
) )
except Exception: except Exception:
base_model = AutoModelForSeq2SeqLM.from_pretrained( model = AutoPeftModelForSeq2SeqLM.from_pretrained(
model_id, model_id,
revision=revision, revision=revision,
trust_remote_code=trust_remote_code torch_dtype=torch_dtype,
trust_remote_code=trust_remote_code,
low_cpu_mem_usage=True,
) )
logger.info(f"Loaded.") logger.info(f"Loaded.")
logger.info(f"Merging the lora weights.")
logger.info(f"Merging the lora weights {repr(peft_model_id)} into the base model {repr(model_id)}") base_model_id = model.peft_config["default"].base_model_name_or_path
model = PeftModel.from_pretrained(
base_model,
peft_model_id,
trust_remote_code=trust_remote_code,
)
model = model.merge_and_unload() model = model.merge_and_unload()
os.makedirs(peft_model_id, exist_ok=True) os.makedirs(model_id, exist_ok=True)
cache_dir = peft_model_id cache_dir = model_id
logger.info(f"Saving the newly created merged model to {cache_dir}") logger.info(f"Saving the newly created merged model to {cache_dir}")
tokenizer = AutoTokenizer.from_pretrained(model_id) tokenizer = AutoTokenizer.from_pretrained(base_model_id)
model.save_pretrained(cache_dir, safe_serialization=True) model.save_pretrained(cache_dir, safe_serialization=True)
config.save_pretrained(cache_dir) model.config.save_pretrained(cache_dir)
tokenizer.save_pretrained(cache_dir) tokenizer.save_pretrained(cache_dir)