diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py
index 3037f21c..405db350 100644
--- a/server/text_generation_server/cli.py
+++ b/server/text_generation_server/cli.py
@@ -122,7 +122,7 @@ def download_weights(
     if not is_local_model:
         try:
             adapter_config_filename = hf_hub_download(model_id, revision=revision, filename="adapter_config.json")
-            utils.download_and_unload_peft(model_id, revision, adapter_config_filename, trust_remote_code=trust_remote_code)
+            utils.download_and_unload_peft(model_id, revision, trust_remote_code=trust_remote_code)
         except utils.LocalEntryNotFoundError:
             pass

diff --git a/server/text_generation_server/utils/peft.py b/server/text_generation_server/utils/peft.py
index 68744e38..be1f9444 100644
--- a/server/text_generation_server/utils/peft.py
+++ b/server/text_generation_server/utils/peft.py
@@ -1,65 +1,45 @@
 import os
-import sys
 import json
 from loguru import logger
+import torch

-from transformers.models.auto import modeling_auto
-from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoConfig, AutoModel, AutoTokenizer
-from peft import PeftModel
+from transformers import AutoTokenizer
+from peft import AutoPeftModelForCausalLM, AutoPeftModelForSeq2SeqLM


-def download_and_unload_peft(model_id, revision, adapter_config_filename, trust_remote_code):
+def download_and_unload_peft(model_id, revision, trust_remote_code):
+    torch_dtype = torch.float16

     logger.info("Peft model detected.")
-    peft_model_id = model_id
-    with open(adapter_config_filename, "r") as f:
-        adapter_config = json.load(f)
-    model_id = adapter_config["base_model_name_or_path"]
-    logger.info(f"Merging the lora weights {repr(peft_model_id)} into the base model {repr(model_id)}")
-    config = AutoConfig.from_pretrained(model_id, revision=revision, trust_remote_code=trust_remote_code)
-    model_type = config.model_type
-    logger.info(f"Starting to load the base model {repr(model_id)}, this may take a while with no feedback.")
-    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
-        base_model = AutoModelForCausalLM.from_pretrained(
+    logger.info("Loading the model it might take a while without feedback")
+    try:
+        model = AutoPeftModelForCausalLM.from_pretrained(
             model_id,
             revision=revision,
-            trust_remote_code=trust_remote_code
+            torch_dtype=torch_dtype,
+            trust_remote_code=trust_remote_code,
+            low_cpu_mem_usage=True,
         )
-    elif model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
-        base_model = AutoModelForSeq2SeqLM(
+    except Exception:
+        model = AutoPeftModelForSeq2SeqLM.from_pretrained(
             model_id,
             revision=revision,
+            torch_dtype=torch_dtype,
+            trust_remote_code=trust_remote_code,
+            low_cpu_mem_usage=True,
         )
-    else:
-        # We have no idea just try either
-        try:
-            base_model = AutoModelForCausalLM.from_pretrained(
-                model_id,
-                revision=revision,
-                trust_remote_code=trust_remote_code
-            )
-        except Exception:
-            base_model = AutoModelForSeq2SeqLM.from_pretrained(
-                model_id,
-                revision=revision,
-                trust_remote_code=trust_remote_code
-            )
     logger.info(f"Loaded.")
+    logger.info(f"Merging the lora weights.")

-    logger.info(f"Merging the lora weights {repr(peft_model_id)} into the base model {repr(model_id)}")
+    base_model_id = model.peft_config["default"].base_model_name_or_path

-    model = PeftModel.from_pretrained(
-        base_model,
-        peft_model_id,
-        trust_remote_code=trust_remote_code,
-    )
     model = model.merge_and_unload()

-    os.makedirs(peft_model_id, exist_ok=True)
-    cache_dir = peft_model_id
+    os.makedirs(model_id, exist_ok=True)
+    cache_dir = model_id
     logger.info(f"Saving the newly created merged model to {cache_dir}")
-    tokenizer = AutoTokenizer.from_pretrained(model_id)
+    tokenizer = AutoTokenizer.from_pretrained(base_model_id)

     model.save_pretrained(cache_dir, safe_serialization=True)
-    config.save_pretrained(cache_dir)
+    model.config.save_pretrained(cache_dir)
     tokenizer.save_pretrained(cache_dir)
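
For context, here is a minimal standalone sketch of the merge-and-save flow this patch switches to. The adapter id `some-user/some-lora-adapter` is a placeholder, and this sketch omits the `revision`, `trust_remote_code`, and Seq2Seq fallback handling from the diff; the `peft` and `transformers` calls (`AutoPeftModelForCausalLM.from_pretrained`, `merge_and_unload`, `save_pretrained`) are the same ones the new `download_and_unload_peft` uses.

```python
# Sketch of the flow introduced by this diff; the adapter repo id below is a placeholder.
import torch
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer

adapter_id = "some-user/some-lora-adapter"  # hypothetical adapter repo

# AutoPeftModelForCausalLM reads adapter_config.json itself and loads the base
# model referenced there, so no manual config parsing or model-type dispatch
# is needed anymore.
model = AutoPeftModelForCausalLM.from_pretrained(
    adapter_id,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)

# The base model id now comes from the loaded peft config rather than from a
# separately downloaded adapter_config.json.
base_model_id = model.peft_config["default"].base_model_name_or_path
tokenizer = AutoTokenizer.from_pretrained(base_model_id)

# Fold the LoRA deltas into the base weights and drop the peft wrapper.
merged = model.merge_and_unload()

# Save the merged weights under the adapter id, where TGI will pick them up
# as a plain (non-peft) model on the next load.
merged.save_pretrained(adapter_id, safe_serialization=True)
merged.config.save_pretrained(adapter_id)
tokenizer.save_pretrained(adapter_id)
```

One note on the design as it appears in the diff: loading in `float16` with `low_cpu_mem_usage=True` keeps host memory roughly at one copy of the base weights during the merge, and the bare `try`/`except Exception` around the causal-LM load is what replaces the old `model_type` lookup for falling back to Seq2Seq models.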