This commit is contained in:
Geoffrey Angus 2023-08-15 15:37:54 -07:00
parent ab0937b90c
commit ae5beb9d7b

View File

@ -4,6 +4,7 @@ from pathlib import Path
from typing import List, Dict, Set, Tuple from typing import List, Dict, Set, Tuple
import torch import torch
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
from loguru import logger from loguru import logger
from peft import LoraConfig from peft import LoraConfig
from peft.utils import transpose from peft.utils import transpose
@ -76,7 +77,8 @@ def create_merged_weight_files(
adapter_config = LoraConfig.from_pretrained(adapter_id) adapter_config = LoraConfig.from_pretrained(adapter_id)
if adapter_config.base_model_name_or_path != model_id: if adapter_config.base_model_name_or_path != model_id:
raise ValueError(f"Adapter {adapter_id} is not compatible with model {model_id}") raise ValueError(f"Adapter '{adapter_id}' is not compatible with model '{model_id}'. "
f"Use --model-id '{adapter_config.base_model_name_or_path}' instead.")
# load adapter weights from all shards (should have relatively small memory footprint) # load adapter weights from all shards (should have relatively small memory footprint)
adapter_weights = {} adapter_weights = {}
@ -84,15 +86,21 @@ def create_merged_weight_files(
adapter_weights.update(load_file(filename)) adapter_weights.update(load_file(filename))
remaining_adapter_weight_names = set(adapter_weights.keys()) remaining_adapter_weight_names = set(adapter_weights.keys())
merged_weight_directory = f"/data/{adapter_id.replace('/', '--')}-merged/" merged_weight_directory = Path(HUGGINGFACE_HUB_CACHE) / f"models--{adapter_id.replace('/', '--')}-merged"
# just grab the existing files if they already exist and return immediately # just grab the existing files if they already exist and return immediately
if os.path.exists(merged_weight_directory): if os.path.exists(merged_weight_directory):
logger.info("Merged weight files already exist, skipping merge computation.") logger.info("Merged weight files already exist, skipping merge computation.")
return weight_files(merged_weight_directory) return weight_files(merged_weight_directory)
else:
logger.info("Merged weight files do not exist, computing merge.")
os.makedirs(merged_weight_directory)
os.makedirs(merged_weight_directory)
merged_weight_filenames = [] merged_weight_filenames = []
for filename in model_weight_filenames: for i, filename in enumerate(model_weight_filenames):
logger.info(
f"Merging adapter weights into model weights in "
f"{filename} ({i+1} / {len(model_weight_filenames)})..."
)
model_weights = load_file(filename) model_weights = load_file(filename)
merged_weights, processed_adapter_weight_names = merge_adapter_weights( merged_weights, processed_adapter_weight_names = merge_adapter_weights(
model_weights, adapter_weights, adapter_config) model_weights, adapter_weights, adapter_config)
@ -110,6 +118,8 @@ def create_merged_weight_files(
for lora_name in remaining_adapter_weight_names: for lora_name in remaining_adapter_weight_names:
logger.warning("\t" + lora_name) logger.warning("\t" + lora_name)
logger.info(
f"Finished merging adapter weights. Merged weight files saved to: {merged_weight_directory}")
return merged_weight_filenames return merged_weight_filenames