mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-05-03 16:22:06 +00:00
fix(server): Handle loading from local files for MPT (#534)
This PR allows the MPT model to be loaded from local files. Without this change, an exception will be thrown by the `hf_hub_download` function if `model_id` is a local path.
This commit is contained in:
parent
e6888d0e87
commit
2a101207d4
@ -1,6 +1,7 @@
|
|||||||
import torch
|
import torch
|
||||||
import torch.distributed
|
import torch.distributed
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
from typing import Optional, Type
|
from typing import Optional, Type
|
||||||
from opentelemetry import trace
|
from opentelemetry import trace
|
||||||
from transformers import AutoTokenizer, PretrainedConfig, PreTrainedTokenizerBase
|
from transformers import AutoTokenizer, PretrainedConfig, PreTrainedTokenizerBase
|
||||||
@ -60,7 +61,12 @@ class MPTSharded(CausalLM):
|
|||||||
)
|
)
|
||||||
tokenizer.pad_token = tokenizer.eos_token
|
tokenizer.pad_token = tokenizer.eos_token
|
||||||
|
|
||||||
filename = hf_hub_download(model_id, revision=revision, filename="config.json")
|
# If model_id is a local path, load the file directly
|
||||||
|
local_path = Path(model_id, "config.json")
|
||||||
|
if local_path.exists():
|
||||||
|
filename = str(local_path.resolve())
|
||||||
|
else:
|
||||||
|
filename = hf_hub_download(model_id, revision=revision, filename="config.json")
|
||||||
with open(filename, "r") as f:
|
with open(filename, "r") as f:
|
||||||
config = json.load(f)
|
config = json.load(f)
|
||||||
config = PretrainedConfig(**config)
|
config = PretrainedConfig(**config)
|
||||||
|
Loading…
Reference in New Issue
Block a user