mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 11:54:52 +00:00
Cleaner peft code.
This commit is contained in:
parent
1569558750
commit
9d5a018fac
@ -122,7 +122,7 @@ def download_weights(
|
||||
if not is_local_model:
|
||||
try:
|
||||
adapter_config_filename = hf_hub_download(model_id, revision=revision, filename="adapter_config.json")
|
||||
utils.download_and_unload_peft(model_id, revision, adapter_config_filename, trust_remote_code=trust_remote_code)
|
||||
utils.download_and_unload_peft(model_id, revision, trust_remote_code=trust_remote_code)
|
||||
except utils.LocalEntryNotFoundError:
|
||||
pass
|
||||
|
||||
|
@ -1,65 +1,45 @@
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
from loguru import logger
|
||||
import torch
|
||||
|
||||
from transformers.models.auto import modeling_auto
|
||||
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoConfig, AutoModel, AutoTokenizer
|
||||
from peft import PeftModel
|
||||
from transformers import AutoTokenizer
|
||||
from peft import AutoPeftModelForCausalLM, AutoPeftModelForSeq2SeqLM
|
||||
|
||||
def download_and_unload_peft(model_id, revision, trust_remote_code):
    """Download a PEFT (LoRA) adapter, merge it into its base model, and save
    the merged weights as a plain checkpoint.

    The merged model, its config, and the base model's tokenizer are written
    to a local directory named after ``model_id``, so downstream loading code
    can treat the result as an ordinary (non-PEFT) model.

    Args:
        model_id: Hub id (or local path) of the PEFT adapter repository.
            Also reused as the output directory for the merged checkpoint.
        revision: Git revision of the adapter repository to download.
        trust_remote_code: Forwarded to ``from_pretrained`` for repositories
            that ship custom modeling code.
    """
    # fp16 keeps the merge memory footprint manageable; the serving stack
    # loads fp16 weights anyway.
    torch_dtype = torch.float16

    logger.info("Peft model detected.")
    logger.info("Loading the model it might take a while without feedback")
    try:
        model = AutoPeftModelForCausalLM.from_pretrained(
            model_id,
            revision=revision,
            torch_dtype=torch_dtype,
            trust_remote_code=trust_remote_code,
            low_cpu_mem_usage=True,
        )
    except Exception:
        # Broad on purpose: if the adapter is not a causal-LM (or the causal
        # load fails for any reason), retry as a seq2seq model. A failure here
        # propagates to the caller.
        model = AutoPeftModelForSeq2SeqLM.from_pretrained(
            model_id,
            revision=revision,
            torch_dtype=torch_dtype,
            trust_remote_code=trust_remote_code,
            low_cpu_mem_usage=True,
        )
    logger.info("Loaded.")
    logger.info("Merging the lora weights.")
    # Remember the base model id before merging: the tokenizer lives in the
    # base repository, not in the adapter repository.
    base_model_id = model.peft_config["default"].base_model_name_or_path

    # Fold the LoRA deltas into the base weights and drop the PEFT wrapper.
    model = model.merge_and_unload()

    os.makedirs(model_id, exist_ok=True)
    cache_dir = model_id
    logger.info(f"Saving the newly created merged model to {cache_dir}")
    tokenizer = AutoTokenizer.from_pretrained(base_model_id)
    model.save_pretrained(cache_dir, safe_serialization=True)
    model.config.save_pretrained(cache_dir)
    tokenizer.save_pretrained(cache_dir)
Loading…
Reference in New Issue
Block a user