import torch

from loguru import logger
from transformers.configuration_utils import PretrainedConfig
from transformers.models.auto import modeling_auto
from huggingface_hub import hf_hub_download
from typing import Optional
from pathlib import Path

# Needed to properly set up habana_frameworks (imported for its side effects)
import text_generation_server.habana_quantization_env as hq_env

from text_generation_server.utils.speculate import get_speculate, set_speculate
from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM
from text_generation_server.models.bloom import BLOOM
from text_generation_server.models.starcoder import StarCoder

from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi


# Disable gradients
torch.set_grad_enabled(False)


def get_model(
    model_id: str,
    revision: Optional[str],
    speculate: Optional[int],
    dtype: Optional[torch.dtype],
    trust_remote_code: bool,
) -> Model:
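    """Instantiate the Gaudi-optimized model implementation for `model_id`.

    Resolves the speculative-decoding configuration (an explicit `speculate`
    value, Medusa heads advertised in the model config, or the n-gram
    fallback) and then routes to StarCoder, BLOOM, or the generic CausalLM
    wrapper based on the model's `model_type`.
    """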
    adapt_transformers_to_gaudi()

    if speculate is not None:
        set_speculate(speculate)
    else:
        set_speculate(0)

    config_dict, _ = PretrainedConfig.get_config_dict(
        model_id, revision=revision, trust_remote_code=trust_remote_code
    )
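
    # When the config advertises `medusa_num_heads`, `model_id` actually points
    # at a Medusa speculator: load the base model from `base_model_name_or_path`
    # and keep the original id/revision around to fetch the Medusa head weights.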
    use_medusa = None
    if "medusa_num_heads" in config_dict:
        medusa_model_id = model_id
        medusa_revision = revision
        model_id = config_dict["base_model_name_or_path"]
        revision = "main"
        speculate_medusa = config_dict["medusa_num_heads"]
        if speculate is not None:
            if speculate > speculate_medusa:
                raise RuntimeError(
                    f"Speculate is set to `{speculate}`, but this Medusa model only has `{speculate_medusa}` heads; please make them match."
                )
            else:
                set_speculate(speculate)
        else:
            set_speculate(speculate_medusa)

        config_dict, _ = PretrainedConfig.get_config_dict(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        is_local = Path(medusa_model_id).exists()
        if not is_local:
            medusa_config = hf_hub_download(
                medusa_model_id, revision=medusa_revision, filename="config.json"
            )
            hf_hub_download(
                medusa_model_id,
                revision=medusa_revision,
                filename="medusa_lm_head.safetensors",
            )
            use_medusa = Path(medusa_config).parent
        else:
            use_medusa = Path(medusa_model_id)

        method = "medusa"
    else:
        method = "n-gram"
    speculate = get_speculate()
    if speculate > 0:
        logger.info(f"Using speculation {method} with {speculate} input ids.")

    model_type = config_dict["model_type"]
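
    # Dispatch on the architecture: StarCoder and BLOOM have dedicated
    # implementations here; anything else covered by transformers' causal-LM
    # mapping falls through to the generic CausalLM wrapper.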
    if model_type == "gpt_bigcode":
        return StarCoder(model_id, revision, dtype)

    if model_type == "bloom":
        return BLOOM(
            model_id,
            revision,
            use_medusa=use_medusa,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
        return CausalLM(
            model_id,
            revision,
            use_medusa=use_medusa,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    raise ValueError(f"Unsupported model type {model_type}")
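

# A minimal usage sketch (illustrative only; the model id and dtype below are
# placeholders, not values shipped with this module):
#
#     model = get_model(
#         model_id="bigscience/bloom-560m",
#         revision=None,
#         speculate=None,
#         dtype=torch.bfloat16,
#         trust_remote_code=False,
#     )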