mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 04:14:52 +00:00
Load local peft model and add print stmts
This commit is contained in:
parent
a5def7c222
commit
3be198c698
@ -123,6 +123,7 @@ def download_weights(
|
||||
"WEIGHTS_CACHE_OVERRIDE", None
|
||||
) is not None
|
||||
|
||||
print(f"is_local_model: {is_local_model}")
|
||||
if not is_local_model:
|
||||
try:
|
||||
adapter_config_filename = hf_hub_download(
|
||||
@ -150,6 +151,23 @@ def download_weights(
|
||||
if not extension == ".safetensors" or not auto_convert:
|
||||
raise e
|
||||
|
||||
# Try to load as a PEFT model
|
||||
# Newly added
|
||||
try:
|
||||
|
||||
# adapter_config_filename = hf_hub_download(
|
||||
# model_id, revision=revision, filename="adapter_config.json"
|
||||
# )
|
||||
|
||||
utils.download_and_unload_peft(
|
||||
model_id, revision, trust_remote_code=trust_remote_code
|
||||
)
|
||||
utils.weight_files(model_id, revision, extension)
|
||||
return
|
||||
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
|
||||
pass
|
||||
|
||||
|
||||
# Try to see if there are local pytorch weights
|
||||
try:
|
||||
# Get weights for a local model, a hub cached model and inside the WEIGHTS_CACHE_OVERRIDE
|
||||
|
@ -83,8 +83,11 @@ def weight_files(
|
||||
model_id: str, revision: Optional[str] = None, extension: str = ".safetensors"
|
||||
) -> List[Path]:
|
||||
"""Get the local files"""
|
||||
print(f"weight_files called with model_id: {model_id} revision: {revision} extension: {extension}")
|
||||
|
||||
# Local model
|
||||
if Path(model_id).exists() and Path(model_id).is_dir():
|
||||
print(f"Finding local files with extension: {extension}")
|
||||
local_files = list(Path(model_id).glob(f"*{extension}"))
|
||||
if not local_files:
|
||||
raise FileNotFoundError(
|
||||
|
@ -8,6 +8,8 @@ from peft import AutoPeftModelForCausalLM, AutoPeftModelForSeq2SeqLM
|
||||
|
||||
|
||||
def download_and_unload_peft(model_id, revision, trust_remote_code):
|
||||
print(f"download_and_unload_peft called with model_id: {model_id} revision: {revision} tmc: {trust_remote_code}")
|
||||
|
||||
torch_dtype = torch.float16
|
||||
|
||||
logger.info("Peft model detected.")
|
||||
@ -35,6 +37,7 @@ def download_and_unload_peft(model_id, revision, trust_remote_code):
|
||||
|
||||
model = model.merge_and_unload()
|
||||
|
||||
print(f"Creating dir: {model_id}")
|
||||
os.makedirs(model_id, exist_ok=True)
|
||||
cache_dir = model_id
|
||||
logger.info(f"Saving the newly created merged model to {cache_dir}")
|
||||
|
Loading…
Reference in New Issue
Block a user