text-generation-inference/server/text_generation_server/models/__init__.py

import torch

from loguru import logger
from transformers.configuration_utils import PretrainedConfig
from transformers.models.auto import modeling_auto
from typing import Optional

# Needed to properly setup habana_frameworks
import text_generation_server.habana_quantization_env as hq_env

from text_generation_server.utils.speculate import get_speculate, set_speculate
from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM
from text_generation_server.models.bloom import BLOOM
from text_generation_server.models.santacoder import SantaCoder

# Disable gradients
torch.set_grad_enabled(False)

def get_model(
    model_id: str,
    revision: Optional[str],
    speculate: Optional[int],
    dtype: Optional[torch.dtype],
    trust_remote_code: bool,
) -> Model:
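    """Select and instantiate the model backend that matches ``model_id``'s configuration."""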
    if speculate is not None:
        set_speculate(speculate)
    else:
        set_speculate(0)

    config_dict, _ = PretrainedConfig.get_config_dict(
        model_id, revision=revision, trust_remote_code=trust_remote_code
    )

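    # Medusa checkpoints reference the base model they wrap; resolve that base model
    # and take the speculation length from the checkpoint's head count.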
    use_medusa = None
    if "medusa_num_heads" in config_dict:
        use_medusa = model_id
        model_id = config_dict["base_model_name_or_path"]
        revision = "main"
        speculate_medusa = config_dict["medusa_num_heads"]
        if speculate is not None:
            if speculate > speculate_medusa:
                raise RuntimeError(
                    f"Speculate is set to `{speculate}` but this Medusa model only has `{speculate_medusa}` heads, please make them match"
                )
            else:
                set_speculate(speculate)
        else:
            set_speculate(speculate_medusa)
        config_dict, _ = PretrainedConfig.get_config_dict(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        method = "medusa"
    else:
        method = "n-gram"

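    # Report the effective speculation strategy and the number of speculative input ids.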
    speculate = get_speculate()
    if speculate > 0:
        logger.info(f"Using speculation {method} with {speculate} input ids.")

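    # Dispatch on the model architecture declared in the configuration.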
    model_type = config_dict["model_type"]

    if model_type == "gpt_bigcode":
        return SantaCoder(model_id, revision, dtype)

    if model_type == "bloom":
        return BLOOM(model_id, revision, dtype)

    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
        return CausalLM(model_id, revision, dtype)

    raise ValueError(f"Unsupported model type {model_type}")