This commit is contained in:
Geoffrey Angus 2023-08-15 15:37:54 -07:00
parent ab0937b90c
commit ae5beb9d7b

View File

@ -4,6 +4,7 @@ from pathlib import Path
from typing import List, Dict, Set, Tuple
import torch
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
from loguru import logger
from peft import LoraConfig
from peft.utils import transpose
@ -76,7 +77,8 @@ def create_merged_weight_files(
adapter_config = LoraConfig.from_pretrained(adapter_id)
if adapter_config.base_model_name_or_path != model_id:
raise ValueError(f"Adapter {adapter_id} is not compatible with model {model_id}")
raise ValueError(f"Adapter '{adapter_id}' is not compatible with model '{model_id}'. "
f"Use --model-id '{adapter_config.base_model_name_or_path}' instead.")
# load adapter weights from all shards (should have relatively small memory footprint)
adapter_weights = {}
@ -84,15 +86,21 @@ def create_merged_weight_files(
adapter_weights.update(load_file(filename))
remaining_adapter_weight_names = set(adapter_weights.keys())
merged_weight_directory = f"/data/{adapter_id.replace('/', '--')}-merged/"
merged_weight_directory = Path(HUGGINGFACE_HUB_CACHE) / f"models--{adapter_id.replace('/', '--')}-merged"
# just grab the existing files if they already exist and return immediately
if os.path.exists(merged_weight_directory):
logger.info("Merged weight files already exist, skipping merge computation.")
return weight_files(merged_weight_directory)
else:
logger.info("Merged weight files do not exist, computing merge.")
os.makedirs(merged_weight_directory)
os.makedirs(merged_weight_directory)
merged_weight_filenames = []
for filename in model_weight_filenames:
for i, filename in enumerate(model_weight_filenames):
logger.info(
f"Merging adapter weights into model weights in "
f"{filename} ({i+1} / {len(model_weight_filenames)})..."
)
model_weights = load_file(filename)
merged_weights, processed_adapter_weight_names = merge_adapter_weights(
model_weights, adapter_weights, adapter_config)
@ -110,6 +118,8 @@ def create_merged_weight_files(
for lora_name in remaining_adapter_weight_names:
logger.warning("\t" + lora_name)
logger.info(
f"Finished merging adapter weights. Merged weight files saved to: {merged_weight_directory}")
return merged_weight_filenames