mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 20:04:52 +00:00

commit ae5beb9d7b: add logs
parent ab0937b90c
@@ -4,6 +4,7 @@ from pathlib import Path
 from typing import List, Dict, Set, Tuple
 
 import torch
+from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
 from loguru import logger
 from peft import LoraConfig
 from peft.utils import transpose
@@ -76,7 +77,8 @@ def create_merged_weight_files(
 
     adapter_config = LoraConfig.from_pretrained(adapter_id)
     if adapter_config.base_model_name_or_path != model_id:
-        raise ValueError(f"Adapter {adapter_id} is not compatible with model {model_id}")
+        raise ValueError(f"Adapter '{adapter_id}' is not compatible with model '{model_id}'. "
+                         f"Use --model-id '{adapter_config.base_model_name_or_path}' instead.")
 
     # load adapter weights from all shards (should have relatively small memory footprint)
     adapter_weights = {}
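Note: the hunk above only improves the error text, but the check it sits in gates the whole merge, since a LoRA adapter records the base model it was trained against in its config. A minimal sketch of that check in isolation; the repo ids here are placeholders, not values from this commit:

from peft import LoraConfig

adapter_id = "some-org/some-adapter"  # hypothetical adapter repo
model_id = "some-org/base-model"      # hypothetical base model repo

# LoraConfig stores the base model the adapter was trained on in
# base_model_name_or_path; merging into any other model would be invalid.
config = LoraConfig.from_pretrained(adapter_id)
if config.base_model_name_or_path != model_id:
    raise ValueError(f"Adapter '{adapter_id}' is not compatible with model '{model_id}'. "
                     f"Use --model-id '{config.base_model_name_or_path}' instead.")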
@@ -84,15 +86,21 @@ def create_merged_weight_files(
         adapter_weights.update(load_file(filename))
     remaining_adapter_weight_names = set(adapter_weights.keys())
 
-    merged_weight_directory = f"/data/{adapter_id.replace('/', '--')}-merged/"
+    merged_weight_directory = Path(HUGGINGFACE_HUB_CACHE) / f"models--{adapter_id.replace('/', '--')}-merged"
     # just grab the existing files if they already exist and return immediately
     if os.path.exists(merged_weight_directory):
         logger.info("Merged weight files already exist, skipping merge computation.")
         return weight_files(merged_weight_directory)
-
-    os.makedirs(merged_weight_directory)
+    else:
+        logger.info("Merged weight files do not exist, computing merge.")
+        os.makedirs(merged_weight_directory)
+
     merged_weight_filenames = []
-    for filename in model_weight_filenames:
+    for i, filename in enumerate(model_weight_filenames):
+        logger.info(
+            f"Merging adapter weights into model weights in "
+            f"{filename} ({i+1} / {len(model_weight_filenames)})..."
+        )
         model_weights = load_file(filename)
         merged_weights, processed_adapter_weight_names = merge_adapter_weights(
             model_weights, adapter_weights, adapter_config)
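Note: the path change above moves merged weights out of a hard-coded /data directory and into the standard Hugging Face hub cache, reusing the hub's "models--" naming scheme so merged artifacts live next to downloaded ones. A minimal sketch of how the new path resolves; the adapter id is a placeholder:

from pathlib import Path
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE

adapter_id = "some-org/some-adapter"  # hypothetical adapter repo

# '/' is not valid inside a directory name, so it is replaced with '--',
# matching the hub cache's own "models--org--name" convention.
merged = Path(HUGGINGFACE_HUB_CACHE) / f"models--{adapter_id.replace('/', '--')}-merged"
print(merged)  # e.g. ~/.cache/huggingface/hub/models--some-org--some-adapter-merged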
@@ -110,6 +118,8 @@ def create_merged_weight_files(
     for lora_name in remaining_adapter_weight_names:
         logger.warning("\t" + lora_name)
 
+    logger.info(
+        f"Finished merging adapter weights. Merged weight files saved to: {merged_weight_directory}")
     return merged_weight_filenames
 
 
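Note: merge_adapter_weights itself is outside this diff. Given the imports of torch and peft.utils.transpose, it presumably applies the standard LoRA merge per target weight and reports which adapter tensors it consumed, which is what shrinks remaining_adapter_weight_names above. A minimal sketch of that formula under those assumptions; the function name and signature are illustrative, not from this commit:

import torch
from peft.utils import transpose

def merge_lora_into_weight(base_weight: torch.Tensor,
                           lora_a: torch.Tensor,
                           lora_b: torch.Tensor,
                           lora_alpha: float,
                           r: int,
                           fan_in_fan_out: bool = False) -> torch.Tensor:
    # Standard LoRA merge: W' = W + (alpha / r) * (B @ A), with an optional
    # transpose for Conv1D-style (fan_in_fan_out) layers, as peft does.
    scaling = lora_alpha / r
    delta = transpose(lora_b @ lora_a, fan_in_fan_out) * scaling
    return base_weight + delta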