diff --git a/server/text_generation_server/utils/adapter.py b/server/text_generation_server/utils/adapter.py
index 9f6d4960..76a87ad7 100644
--- a/server/text_generation_server/utils/adapter.py
+++ b/server/text_generation_server/utils/adapter.py
@@ -4,6 +4,7 @@ from pathlib import Path
 from typing import List, Dict, Set, Tuple
 
 import torch
+from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
 from loguru import logger
 from peft import LoraConfig
 from peft.utils import transpose
@@ -76,7 +77,8 @@ def create_merged_weight_files(
     adapter_config = LoraConfig.from_pretrained(adapter_id)
 
     if adapter_config.base_model_name_or_path != model_id:
-        raise ValueError(f"Adapter {adapter_id} is not compatible with model {model_id}")
+        raise ValueError(f"Adapter '{adapter_id}' is not compatible with model '{model_id}'. "
+                         f"Use --model-id '{adapter_config.base_model_name_or_path}' instead.")
 
     # load adapter weights from all shards (should have relatively small memory footprint)
     adapter_weights = {}
@@ -84,15 +86,21 @@
         adapter_weights.update(load_file(filename))
     remaining_adapter_weight_names = set(adapter_weights.keys())
 
-    merged_weight_directory = f"/data/{adapter_id.replace('/', '--')}-merged/"
+    merged_weight_directory = Path(HUGGINGFACE_HUB_CACHE) / f"models--{adapter_id.replace('/', '--')}-merged"
     # just grab the existing files if they already exist and return immediately
     if os.path.exists(merged_weight_directory):
         logger.info("Merged weight files already exist, skipping merge computation.")
         return weight_files(merged_weight_directory)
+    else:
+        logger.info("Merged weight files do not exist, computing merge.")
+        os.makedirs(merged_weight_directory)
 
-    os.makedirs(merged_weight_directory)
     merged_weight_filenames = []
-    for filename in model_weight_filenames:
+    for i, filename in enumerate(model_weight_filenames):
+        logger.info(
+            f"Merging adapter weights into model weights in "
+            f"{filename} ({i+1} / {len(model_weight_filenames)})..."
+        )
         model_weights = load_file(filename)
         merged_weights, processed_adapter_weight_names = merge_adapter_weights(
             model_weights, adapter_weights, adapter_config)
@@ -110,6 +118,8 @@
         for lora_name in remaining_adapter_weight_names:
             logger.warning("\t" + lora_name)
 
+    logger.info(
+        f"Finished merging adapter weights. Merged weight files saved to: {merged_weight_directory}")
     return merged_weight_filenames
 
 
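
Reviewer note: merge_adapter_weights, called once per model shard above, is defined elsewhere in adapter.py and is not part of this diff. For context, here is a minimal sketch of the standard LoRA merge such a helper typically performs, W' = W + (lora_alpha / r) * (B @ A) for each base weight with matching adapter tensors. The sketch function name, the PEFT-style key layout ("base_model.model.<module>.lora_A.weight"), and the return shape are illustrative assumptions, not this repository's actual implementation:

# Hedged sketch only: not the real merge_adapter_weights from this repo.
from typing import Dict, Set, Tuple

import torch
from peft import LoraConfig
from peft.utils import transpose


def merge_adapter_weights_sketch(
    model_weights: Dict[str, torch.Tensor],
    adapter_weights: Dict[str, torch.Tensor],
    adapter_config: LoraConfig,
) -> Tuple[Dict[str, torch.Tensor], Set[str]]:
    # LoRA merge: W' = W + (lora_alpha / r) * (B @ A) for every base weight
    # that has matching lora_A / lora_B tensors in the adapter checkpoint.
    scaling = adapter_config.lora_alpha / adapter_config.r
    merged = dict(model_weights)
    processed: Set[str] = set()
    for name, weight in model_weights.items():
        prefix = name.removesuffix(".weight")
        # Assumed PEFT key layout, e.g.
        # "base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight".
        lora_a = f"base_model.model.{prefix}.lora_A.weight"
        lora_b = f"base_model.model.{prefix}.lora_B.weight"
        if lora_a not in adapter_weights or lora_b not in adapter_weights:
            continue
        delta = transpose(
            adapter_weights[lora_b] @ adapter_weights[lora_a],
            adapter_config.fan_in_fan_out,
        )
        merged[name] = weight + scaling * delta
        processed.update((lora_a, lora_b))
    return merged, processed

Whatever the exact key scheme, returning the set of processed adapter names is what lets the caller shrink remaining_adapter_weight_names and warn about any LoRA tensors that never matched a base weight, as the hunk at old line 110 does. Merging shard by shard also bounds peak memory to a single model shard plus the (comparatively small) adapter weights, which is why the loop above iterates over model_weight_filenames one file at a time.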