Handle loading from local files for MPT

This commit is contained in:
Antoni Baum 2023-07-03 12:19:54 -07:00 committed by GitHub
parent 1da07e85aa
commit 5c490fb56a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,6 +1,7 @@
import torch
import torch.distributed
from pathlib import Path
from typing import Optional, Type
from opentelemetry import trace
from transformers import AutoTokenizer, PretrainedConfig, PreTrainedTokenizerBase
@ -60,7 +61,12 @@ class MPTSharded(CausalLM):
)
tokenizer.pad_token = tokenizer.eos_token
filename = hf_hub_download(model_id, revision=revision, filename="config.json")
# If model_id is a local path, load the file directly
local_path = Path(model_id, "config.json")
if local_path.exists():
filename = str(local_path.resolve())
else:
filename = hf_hub_download(model_id, revision=revision, filename="config.json")
with open(filename, "r") as f:
config = json.load(f)
config = PretrainedConfig(**config)