fix(server): Handle loading from local files for MPT (#534)

This PR allows the MPT model to be loaded from local files. Without this
change, the `hf_hub_download` function throws an exception when
`model_id` is a local path.
This commit is contained in:
Antoni Baum 2023-07-04 09:37:25 -07:00 committed by GitHub
parent e6888d0e87
commit 2a101207d4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@@ -1,6 +1,7 @@
import torch import torch
import torch.distributed import torch.distributed
from pathlib import Path
from typing import Optional, Type from typing import Optional, Type
from opentelemetry import trace from opentelemetry import trace
from transformers import AutoTokenizer, PretrainedConfig, PreTrainedTokenizerBase from transformers import AutoTokenizer, PretrainedConfig, PreTrainedTokenizerBase
@@ -60,7 +61,12 @@ class MPTSharded(CausalLM):
) )
tokenizer.pad_token = tokenizer.eos_token tokenizer.pad_token = tokenizer.eos_token
filename = hf_hub_download(model_id, revision=revision, filename="config.json") # If model_id is a local path, load the file directly
local_path = Path(model_id, "config.json")
if local_path.exists():
filename = str(local_path.resolve())
else:
filename = hf_hub_download(model_id, revision=revision, filename="config.json")
with open(filename, "r") as f: with open(filename, "r") as f:
config = json.load(f) config = json.load(f)
config = PretrainedConfig(**config) config = PretrainedConfig(**config)