Load local peft model and add print stmts

This commit is contained in:
Traun Leyden 2023-11-14 12:46:53 +01:00
parent a5def7c222
commit 3be198c698
3 changed files with 24 additions and 0 deletions

View File

@ -123,6 +123,7 @@ def download_weights(
"WEIGHTS_CACHE_OVERRIDE", None
) is not None
print(f"is_local_model: {is_local_model}")
if not is_local_model:
try:
adapter_config_filename = hf_hub_download(
@ -149,6 +150,23 @@ def download_weights(
# Check if we want to automatically convert to safetensors or if we can use .bin weights instead
if not extension == ".safetensors" or not auto_convert:
raise e
# Try to load as a PEFT model
# Newly added
try:
# adapter_config_filename = hf_hub_download(
# model_id, revision=revision, filename="adapter_config.json"
# )
utils.download_and_unload_peft(
model_id, revision, trust_remote_code=trust_remote_code
)
utils.weight_files(model_id, revision, extension)
return
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
pass
# Try to see if there are local pytorch weights
try:

View File

@ -83,8 +83,11 @@ def weight_files(
model_id: str, revision: Optional[str] = None, extension: str = ".safetensors"
) -> List[Path]:
"""Get the local files"""
print(f"weight_files called with model_id: {model_id} revision: {revision} extension: {extension}")
# Local model
if Path(model_id).exists() and Path(model_id).is_dir():
print(f"Finding local files with extension: {extension}")
local_files = list(Path(model_id).glob(f"*{extension}"))
if not local_files:
raise FileNotFoundError(

View File

@ -8,6 +8,8 @@ from peft import AutoPeftModelForCausalLM, AutoPeftModelForSeq2SeqLM
def download_and_unload_peft(model_id, revision, trust_remote_code):
print(f"download_and_unload_peft called with model_id: {model_id} revision: {revision} tmc: {trust_remote_code}")
torch_dtype = torch.float16
logger.info("Peft model detected.")
@ -35,6 +37,7 @@ def download_and_unload_peft(model_id, revision, trust_remote_code):
model = model.merge_and_unload()
print(f"Creating dir: {model_id}")
os.makedirs(model_id, exist_ok=True)
cache_dir = model_id
logger.info(f"Saving the newly created merged model to {cache_dir}")