text-generation-inference/server/text_generation_server/models/__init__.py

import torch
from loguru import logger
from typing import Optional

from transformers import AutoConfig
from transformers.models.auto import modeling_auto
from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi

from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM
from text_generation_server.models.bloom import BLOOM
from text_generation_server.models.santacoder import SantaCoder

# Inference only: disable gradient tracking globally.
torch.set_grad_enabled(False)


def get_model(
    model_id: str,
    revision: Optional[str],
    dtype: Optional[torch.dtype] = None,
) -> Model:
    # Patch transformers so models run on Habana Gaudi devices.
    adapt_transformers_to_gaudi()

    config = AutoConfig.from_pretrained(model_id, revision=revision)
    model_type = config.model_type

    # Dispatch to a dedicated implementation where one exists; otherwise
    # fall back to the generic causal-LM wrapper.
    if model_type == "gpt_bigcode":
        return SantaCoder(model_id, revision, dtype)

    if model_type == "bloom":
        return BLOOM(model_id, revision, dtype)

    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
        return CausalLM(model_id, revision, dtype)

    raise ValueError(f"Unsupported model type {model_type}")
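
# Usage sketch (illustrative only, not part of this module): how a caller
# might use get_model. The model id "bigscience/bloom-560m" and the bfloat16
# dtype are assumptions chosen for the example, not values taken from this
# file.
#
#     model = get_model("bigscience/bloom-560m", revision=None, dtype=torch.bfloat16)
#     # config.model_type == "bloom", so this returns a BLOOM instance.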