diff --git a/server/text_generation_server/models/mpt.py b/server/text_generation_server/models/mpt.py index 87fc1f07..3c0f8167 100644 --- a/server/text_generation_server/models/mpt.py +++ b/server/text_generation_server/models/mpt.py @@ -1,6 +1,7 @@ import torch import torch.distributed +from pathlib import Path from typing import Optional, Type from opentelemetry import trace from transformers import AutoTokenizer, PretrainedConfig, PreTrainedTokenizerBase @@ -60,7 +61,12 @@ class MPTSharded(CausalLM): ) tokenizer.pad_token = tokenizer.eos_token - filename = hf_hub_download(model_id, revision=revision, filename="config.json") + # If model_id is a local path, load config.json from it directly + local_path = Path(model_id, "config.json") + if local_path.exists(): + filename = str(local_path.resolve()) + else: + filename = hf_hub_download(model_id, revision=revision, filename="config.json") with open(filename, "r") as f: config = json.load(f) config = PretrainedConfig(**config)