Reuse the same function to list local weights everywhere

Signed-off-by: Raphael Glon <oOraph@users.noreply.github.com>
This commit is contained in:
Raphael Glon 2023-12-15 12:48:05 +01:00
parent 29f87920a8
commit d5b7e6e38f
No known key found for this signature in database
GPG Key ID: 4D4CC6881E12A0C3

View File

@ -125,13 +125,14 @@ def weight_files(
) -> List[Path]: ) -> List[Path]:
"""Get the local files""" """Get the local files"""
# Local model # Local model
if Path(model_id).exists() and Path(model_id).is_dir(): d = Path(model_id)
local_files = list(Path(model_id).glob(f"*{extension}")) if d.exists() and d.is_dir():
local_files = _weight_files_from_dir(d, extension)
if not local_files: if not local_files:
raise FileNotFoundError( raise FileNotFoundError(
f"No local weights found in {model_id} with extension {extension}" f"No local weights found in {model_id} with extension {extension}"
) )
return local_files return [Path(f) for f in local_files]
try: try:
filenames = weight_hub_files(model_id, revision, extension) filenames = weight_hub_files(model_id, revision, extension)