Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-10 11:54:52 +00:00)
Handle loading from local files for MPT
commit 5c490fb56a
parent 1da07e85aa
@@ -1,6 +1,7 @@
 import torch
 import torch.distributed
 
+from pathlib import Path
 from typing import Optional, Type
 from opentelemetry import trace
 from transformers import AutoTokenizer, PretrainedConfig, PreTrainedTokenizerBase
@@ -60,7 +61,12 @@ class MPTSharded(CausalLM):
         )
         tokenizer.pad_token = tokenizer.eos_token
 
-        filename = hf_hub_download(model_id, revision=revision, filename="config.json")
+        # If model_id is a local path, load the file directly
+        local_path = Path(model_id, "config.json")
+        if local_path.exists():
+            filename = str(local_path.resolve())
+        else:
+            filename = hf_hub_download(model_id, revision=revision, filename="config.json")
         with open(filename, "r") as f:
             config = json.load(f)
         config = PretrainedConfig(**config)
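The change checks for a config.json next to a local model_id before falling back to a Hub download. Pulled out of the diff as a standalone sketch (assuming huggingface_hub and transformers are installed; the load_config wrapper and the "mosaicml/mpt-7b" id below are illustrative, not part of the commit):

import json
from pathlib import Path
from typing import Optional

from huggingface_hub import hf_hub_download
from transformers import PretrainedConfig


def load_config(model_id: str, revision: Optional[str] = None) -> PretrainedConfig:
    # Prefer a config.json that already exists on local disk.
    local_path = Path(model_id, "config.json")
    if local_path.exists():
        filename = str(local_path.resolve())
    else:
        # Otherwise fetch it from the Hugging Face Hub, as before this commit.
        filename = hf_hub_download(model_id, revision=revision, filename="config.json")
    with open(filename, "r") as f:
        config = json.load(f)
    return PretrainedConfig(**config)


config = load_config("mosaicml/mpt-7b")

Because the local path is tried first, passing a directory path as model_id never triggers a network call, which is what the commit title promises.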