fix: fix local loading for .bin models (#1419)

This commit is contained in:
OlivierDehaene 2024-01-09 15:21:00 +01:00 committed by GitHub
parent 3f9b3f4539
commit 564f2a3b75
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 3 additions and 4 deletions

View File

@ -198,7 +198,7 @@ def download_weights(
if not extension == ".safetensors" or not auto_convert: if not extension == ".safetensors" or not auto_convert:
raise e raise e
else: elif (Path(model_id) / "adapter_config.json").exists():
# Try to load as a local PEFT model # Try to load as a local PEFT model
try: try:
utils.download_and_unload_peft( utils.download_and_unload_peft(

View File

@ -10,8 +10,7 @@ from peft import AutoPeftModelForCausalLM, AutoPeftModelForSeq2SeqLM
def download_and_unload_peft(model_id, revision, trust_remote_code): def download_and_unload_peft(model_id, revision, trust_remote_code):
torch_dtype = torch.float16 torch_dtype = torch.float16
logger.info("Peft model detected.") logger.info("Trying to load a Peft model. It might take a while without feedback")
logger.info("Loading the model it might take a while without feedback")
try: try:
model = AutoPeftModelForCausalLM.from_pretrained( model = AutoPeftModelForCausalLM.from_pretrained(
model_id, model_id,
@ -28,7 +27,7 @@ def download_and_unload_peft(model_id, revision, trust_remote_code):
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
low_cpu_mem_usage=True, low_cpu_mem_usage=True,
) )
logger.info(f"Loaded.") logger.info("Peft model detected.")
logger.info(f"Merging the lora weights.") logger.info(f"Merging the lora weights.")
base_model_id = model.peft_config["default"].base_model_name_or_path base_model_id = model.peft_config["default"].base_model_name_or_path