2023-01-20 11:24:39 +00:00
|
|
|
import torch
|
|
|
|
|
2023-03-24 13:02:14 +00:00
|
|
|
from loguru import logger
|
2023-03-27 07:23:22 +00:00
|
|
|
from transformers.models.auto import modeling_auto
|
2023-12-05 10:12:16 +00:00
|
|
|
from transformers import AutoConfig
|
2023-01-31 17:53:56 +00:00
|
|
|
from typing import Optional
|
|
|
|
|
2023-03-07 17:52:22 +00:00
|
|
|
from text_generation_server.models.model import Model
|
|
|
|
from text_generation_server.models.causal_lm import CausalLM
|
2023-12-05 10:12:16 +00:00
|
|
|
from text_generation_server.models.bloom import BLOOM
|
2023-03-07 17:52:22 +00:00
|
|
|
from text_generation_server.models.santacoder import SantaCoder
|
2023-01-20 11:24:39 +00:00
|
|
|
|
2023-06-19 07:53:45 +00:00
|
|
|
|
|
|
|
# Disable gradients
# This module only ever runs inference, so autograd bookkeeping is never
# needed; turning it off globally saves memory and per-op overhead.
# NOTE(review): this is a process-wide side effect at import time — any code
# importing this module loses gradient tracking. Presumably intentional for
# an inference server; confirm no training path imports this.
torch.set_grad_enabled(False)
|
2022-10-28 17:24:00 +00:00
|
|
|
|
2023-01-31 17:53:56 +00:00
|
|
|
def get_model(
    model_id: str,
    revision: Optional[str],
    dtype: Optional[torch.dtype] = None,
) -> Model:
    """Instantiate the concrete ``Model`` implementation for ``model_id``.

    The Hugging Face config for the checkpoint is fetched first; its
    ``model_type`` field selects which wrapper class to construct.

    Args:
        model_id: Hub repo id or local path of the checkpoint.
        revision: Optional git revision (branch/tag/sha) to load from.
        dtype: Optional torch dtype forwarded to the model wrapper.

    Returns:
        A ``Model`` subclass instance wrapping the loaded checkpoint.

    Raises:
        ValueError: if ``model_type`` has no dedicated wrapper and is not
            supported by transformers' causal-LM auto-mapping.
    """
    config = AutoConfig.from_pretrained(model_id, revision=revision)
    model_type = config.model_type

    # Architectures with a dedicated implementation in this package.
    specialized = {
        "gpt_bigcode": SantaCoder,
        "bloom": BLOOM,
    }
    wrapper = specialized.get(model_type)
    if wrapper is not None:
        return wrapper(model_id, revision, dtype)

    # Anything transformers can build via AutoModelForCausalLM goes through
    # the generic causal-LM wrapper.
    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
        return CausalLM(model_id, revision, dtype)

    raise ValueError(f"Unsupported model type {model_type}")
|