mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 04:14:52 +00:00
Load local peft model and add print stmts
This commit is contained in:
parent
a5def7c222
commit
3be198c698
@ -123,6 +123,7 @@ def download_weights(
|
|||||||
"WEIGHTS_CACHE_OVERRIDE", None
|
"WEIGHTS_CACHE_OVERRIDE", None
|
||||||
) is not None
|
) is not None
|
||||||
|
|
||||||
|
print(f"is_local_model: {is_local_model}")
|
||||||
if not is_local_model:
|
if not is_local_model:
|
||||||
try:
|
try:
|
||||||
adapter_config_filename = hf_hub_download(
|
adapter_config_filename = hf_hub_download(
|
||||||
@ -149,6 +150,23 @@ def download_weights(
|
|||||||
# Check if we want to automatically convert to safetensors or if we can use .bin weights instead
|
# Check if we want to automatically convert to safetensors or if we can use .bin weights instead
|
||||||
if not extension == ".safetensors" or not auto_convert:
|
if not extension == ".safetensors" or not auto_convert:
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
# Try to load as a PEFT model
|
||||||
|
# Newly added
|
||||||
|
try:
|
||||||
|
|
||||||
|
# adapter_config_filename = hf_hub_download(
|
||||||
|
# model_id, revision=revision, filename="adapter_config.json"
|
||||||
|
# )
|
||||||
|
|
||||||
|
utils.download_and_unload_peft(
|
||||||
|
model_id, revision, trust_remote_code=trust_remote_code
|
||||||
|
)
|
||||||
|
utils.weight_files(model_id, revision, extension)
|
||||||
|
return
|
||||||
|
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
# Try to see if there are local pytorch weights
|
# Try to see if there are local pytorch weights
|
||||||
try:
|
try:
|
||||||
|
@ -83,8 +83,11 @@ def weight_files(
|
|||||||
model_id: str, revision: Optional[str] = None, extension: str = ".safetensors"
|
model_id: str, revision: Optional[str] = None, extension: str = ".safetensors"
|
||||||
) -> List[Path]:
|
) -> List[Path]:
|
||||||
"""Get the local files"""
|
"""Get the local files"""
|
||||||
|
print(f"weight_files called with model_id: {model_id} revision: {revision} extension: {extension}")
|
||||||
|
|
||||||
# Local model
|
# Local model
|
||||||
if Path(model_id).exists() and Path(model_id).is_dir():
|
if Path(model_id).exists() and Path(model_id).is_dir():
|
||||||
|
print(f"Finding local files with extension: {extension}")
|
||||||
local_files = list(Path(model_id).glob(f"*{extension}"))
|
local_files = list(Path(model_id).glob(f"*{extension}"))
|
||||||
if not local_files:
|
if not local_files:
|
||||||
raise FileNotFoundError(
|
raise FileNotFoundError(
|
||||||
|
@ -8,6 +8,8 @@ from peft import AutoPeftModelForCausalLM, AutoPeftModelForSeq2SeqLM
|
|||||||
|
|
||||||
|
|
||||||
def download_and_unload_peft(model_id, revision, trust_remote_code):
|
def download_and_unload_peft(model_id, revision, trust_remote_code):
|
||||||
|
print(f"download_and_unload_peft called with model_id: {model_id} revision: {revision} tmc: {trust_remote_code}")
|
||||||
|
|
||||||
torch_dtype = torch.float16
|
torch_dtype = torch.float16
|
||||||
|
|
||||||
logger.info("Peft model detected.")
|
logger.info("Peft model detected.")
|
||||||
@ -35,6 +37,7 @@ def download_and_unload_peft(model_id, revision, trust_remote_code):
|
|||||||
|
|
||||||
model = model.merge_and_unload()
|
model = model.merge_and_unload()
|
||||||
|
|
||||||
|
print(f"Creating dir: {model_id}")
|
||||||
os.makedirs(model_id, exist_ok=True)
|
os.makedirs(model_id, exist_ok=True)
|
||||||
cache_dir = model_id
|
cache_dir = model_id
|
||||||
logger.info(f"Saving the newly created merged model to {cache_dir}")
|
logger.info(f"Saving the newly created merged model to {cache_dir}")
|
||||||
|
Loading…
Reference in New Issue
Block a user