text-generation-inference/backends/gaudi/server/text_generation_server/utils/peft.py
Baptiste Colle 683ff53fa3
Add Gaudi Backend (#3055)
* wip(gaudi): import server and dockerfile from tgi-gaudi fork

* feat(gaudi): new gaudi backend working

* fix: fix style

* fix prehooks issues

* fix(gaudi): refactor server and implement requested changes
2025-02-28 12:14:58 +01:00

69 lines
2.1 KiB
Python

import os
from typing import Union
from loguru import logger
import torch
from transformers import AutoTokenizer
from peft import AutoPeftModelForCausalLM, AutoPeftModelForSeq2SeqLM
def download_and_unload_peft(model_id, revision, trust_remote_code):
torch_dtype = torch.float16
logger.info("Trying to load a Peft model. It might take a while without feedback")
try:
model = AutoPeftModelForCausalLM.from_pretrained(
model_id,
revision=revision,
torch_dtype=torch_dtype,
trust_remote_code=trust_remote_code,
low_cpu_mem_usage=True,
)
except Exception:
model = AutoPeftModelForSeq2SeqLM.from_pretrained(
model_id,
revision=revision,
torch_dtype=torch_dtype,
trust_remote_code=trust_remote_code,
low_cpu_mem_usage=True,
)
logger.info("Peft model detected.")
logger.info("Merging the lora weights.")
base_model_id = model.peft_config["default"].base_model_name_or_path
model = model.merge_and_unload()
os.makedirs(model_id, exist_ok=True)
cache_dir = model_id
logger.info(f"Saving the newly created merged model to {cache_dir}")
tokenizer = AutoTokenizer.from_pretrained(
base_model_id, trust_remote_code=trust_remote_code
)
model.save_pretrained(cache_dir, safe_serialization=True)
model.config.save_pretrained(cache_dir)
tokenizer.save_pretrained(cache_dir)
def download_peft(
model_id: Union[str, os.PathLike], revision: str, trust_remote_code: bool
):
torch_dtype = torch.float16
try:
_model = AutoPeftModelForCausalLM.from_pretrained(
model_id,
revision=revision,
torch_dtype=torch_dtype,
trust_remote_code=trust_remote_code,
low_cpu_mem_usage=True,
)
except Exception:
_model = AutoPeftModelForSeq2SeqLM.from_pretrained(
model_id,
revision=revision,
torch_dtype=torch_dtype,
trust_remote_code=trust_remote_code,
low_cpu_mem_usage=True,
)
logger.info("Peft model downloaded.")