Handle loading from local files for MPT

This commit is contained in:
Antoni Baum 2023-07-03 12:19:54 -07:00 committed by GitHub
parent 1da07e85aa
commit 5c490fb56a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,6 +1,7 @@
import torch import torch
import torch.distributed import torch.distributed
from pathlib import Path
from typing import Optional, Type from typing import Optional, Type
from opentelemetry import trace from opentelemetry import trace
from transformers import AutoTokenizer, PretrainedConfig, PreTrainedTokenizerBase from transformers import AutoTokenizer, PretrainedConfig, PreTrainedTokenizerBase
@ -60,6 +61,11 @@ class MPTSharded(CausalLM):
) )
tokenizer.pad_token = tokenizer.eos_token tokenizer.pad_token = tokenizer.eos_token
# If model_id is a local directory containing config.json, load that file directly; otherwise download it from the Hub
local_path = Path(model_id, "config.json")
if local_path.exists():
filename = str(local_path.resolve())
else:
filename = hf_hub_download(model_id, revision=revision, filename="config.json") filename = hf_hub_download(model_id, revision=revision, filename="config.json")
with open(filename, "r") as f: with open(filename, "r") as f:
config = json.load(f) config = json.load(f)