From 2a101207d44b903c1cc9b4d968a4b24150413942 Mon Sep 17 00:00:00 2001
From: Antoni Baum
Date: Tue, 4 Jul 2023 09:37:25 -0700
Subject: [PATCH] fix(server): Handle loading from local files for MPT (#534)

This PR allows the MPT model to be loaded from local files. Without this
change, an exception will be thrown by the `hf_hub_download` function if
`model_id` is a local path.
---
 server/text_generation_server/models/mpt.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/server/text_generation_server/models/mpt.py b/server/text_generation_server/models/mpt.py
index 87fc1f07..3c0f8167 100644
--- a/server/text_generation_server/models/mpt.py
+++ b/server/text_generation_server/models/mpt.py
@@ -1,6 +1,7 @@
 import torch
 import torch.distributed
 
+from pathlib import Path
 from typing import Optional, Type
 from opentelemetry import trace
 from transformers import AutoTokenizer, PretrainedConfig, PreTrainedTokenizerBase
@@ -60,7 +61,12 @@ class MPTSharded(CausalLM):
         )
         tokenizer.pad_token = tokenizer.eos_token
 
-        filename = hf_hub_download(model_id, revision=revision, filename="config.json")
+        # If model_id is a local path, load the file directly
+        local_path = Path(model_id, "config.json")
+        if local_path.exists():
+            filename = str(local_path.resolve())
+        else:
+            filename = hf_hub_download(model_id, revision=revision, filename="config.json")
         with open(filename, "r") as f:
             config = json.load(f)
         config = PretrainedConfig(**config)
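
As a reference for the pattern applied above, the following is a minimal, self-contained sketch of the local-path fallback, assuming only `huggingface_hub` is installed; the `load_config_dict` helper name is illustrative and not part of the patch:

    import json
    from pathlib import Path
    from typing import Optional

    from huggingface_hub import hf_hub_download


    def load_config_dict(model_id: str, revision: Optional[str] = None) -> dict:
        """Return the parsed config.json for a local directory or a Hub repo id."""
        local_path = Path(model_id, "config.json")
        if local_path.exists():
            # model_id points at a local checkout; read the file directly instead of
            # calling hf_hub_download, which raises when given a filesystem path.
            filename = str(local_path.resolve())
        else:
            # model_id is a Hub repo id; download (or reuse a cached) config.json.
            filename = hf_hub_download(model_id, revision=revision, filename="config.json")
        with open(filename, "r") as f:
            return json.load(f)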