mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 20:04:52 +00:00
add logs
This commit is contained in:
parent
ab0937b90c
commit
ae5beb9d7b
@ -4,6 +4,7 @@ from pathlib import Path
|
||||
from typing import List, Dict, Set, Tuple
|
||||
|
||||
import torch
|
||||
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
|
||||
from loguru import logger
|
||||
from peft import LoraConfig
|
||||
from peft.utils import transpose
|
||||
@ -76,7 +77,8 @@ def create_merged_weight_files(
|
||||
|
||||
adapter_config = LoraConfig.from_pretrained(adapter_id)
|
||||
if adapter_config.base_model_name_or_path != model_id:
|
||||
raise ValueError(f"Adapter {adapter_id} is not compatible with model {model_id}")
|
||||
raise ValueError(f"Adapter '{adapter_id}' is not compatible with model '{model_id}'. "
|
||||
f"Use --model-id '{adapter_config.base_model_name_or_path}' instead.")
|
||||
|
||||
# load adapter weights from all shards (should have relatively small memory footprint)
|
||||
adapter_weights = {}
|
||||
@ -84,15 +86,21 @@ def create_merged_weight_files(
|
||||
adapter_weights.update(load_file(filename))
|
||||
remaining_adapter_weight_names = set(adapter_weights.keys())
|
||||
|
||||
merged_weight_directory = f"/data/{adapter_id.replace('/', '--')}-merged/"
|
||||
merged_weight_directory = Path(HUGGINGFACE_HUB_CACHE) / f"models--{adapter_id.replace('/', '--')}-merged"
|
||||
# just grab the existing files if they already exist and return immediately
|
||||
if os.path.exists(merged_weight_directory):
|
||||
logger.info("Merged weight files already exist, skipping merge computation.")
|
||||
return weight_files(merged_weight_directory)
|
||||
else:
|
||||
logger.info("Merged weight files do not exist, computing merge.")
|
||||
os.makedirs(merged_weight_directory)
|
||||
|
||||
os.makedirs(merged_weight_directory)
|
||||
merged_weight_filenames = []
|
||||
for filename in model_weight_filenames:
|
||||
for i, filename in enumerate(model_weight_filenames):
|
||||
logger.info(
|
||||
f"Merging adapter weights into model weights in "
|
||||
f"{filename} ({i+1} / {len(model_weight_filenames)})..."
|
||||
)
|
||||
model_weights = load_file(filename)
|
||||
merged_weights, processed_adapter_weight_names = merge_adapter_weights(
|
||||
model_weights, adapter_weights, adapter_config)
|
||||
@ -110,6 +118,8 @@ def create_merged_weight_files(
|
||||
for lora_name in remaining_adapter_weight_names:
|
||||
logger.warning("\t" + lora_name)
|
||||
|
||||
logger.info(
|
||||
f"Finished merging adapter weights. Merged weight files saved to: {merged_weight_directory}")
|
||||
return merged_weight_filenames
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user