Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-21 14:52:20 +00:00)
* Working loading state.
* Preprocessing.
* Working state? (Broke idefics1 temporarily.)
* Cleaner condition.
* Fix idefics.
* Updating config, removing TODO.
* Mllama.
* Upgrade transformers 4.45.
* Flashing mllama.
* Starting to get there.
* Working state.
* Integration tests for mllama (cutting to 10 tokens because there seems to be instability afterwards, meaning the batch size matters).
* Updating model link.
* Earlier assert.
* Fix vlm?
* Remove log.
* Force ignore all images but last.
* Default dtype bfloat16.
* Update integration test after switch to bf16.
* Remove dead code.
* Removed dead code.
* Upgrade the flake to latest transformers/tokenizers.
* Move to hf tgi-nix.
* Upgrade to 0.5.0.
338 lines · 12 KiB · Python
import torch

import os

from loguru import logger
from transformers.configuration_utils import PretrainedConfig
from transformers.models.auto import modeling_auto
from huggingface_hub import hf_hub_download, HfApi
from pathlib import Path
from typing import Optional, List, Dict

# Needed to properly setup habana_frameworks
import text_generation_server.habana_quantization_env as hq_env

from text_generation_server.utils.speculate import get_speculate, set_speculate
from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM
from text_generation_server.models.bloom import BLOOM
from text_generation_server.models.starcoder import StarCoder
from text_generation_server.models.vlm_causal_lm import VlmCausalLM
from text_generation_server.models.mllama_causal_lm import (
    MllamaCausalLM,
    MllamaCausalLMBatch,
)
from text_generation_server.models.custom_modeling.llava_next import (
    LlavaNextForConditionalGeneration,
)
from text_generation_server.models.custom_modeling.mllama import (
    MllamaForConditionalGeneration,
)
from text_generation_server.utils.adapter import (
    AdapterParameters,
    build_layer_weight_lookup,
    load_and_merge_adapters,
    AdapterInfo,
)

# LoraWeights.prepare_weights is called in get_model_with_lora_adapters below but was
# not imported anywhere in this file; the module path below is an assumption based on
# the text_generation_server.adapters package layout.
from text_generation_server.adapters.lora import LoraWeights

from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi

# Disable gradients
torch.set_grad_enabled(False)


def get_model(
    model_id: str,
    lora_adapter_ids: Optional[List[str]],
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
    speculate: Optional[int],
    dtype: Optional[torch.dtype],
    trust_remote_code: bool,
    max_input_tokens: int,
) -> Model:
    adapt_transformers_to_gaudi()

    if speculate is not None:
        set_speculate(speculate)
    else:
        set_speculate(0)

    config_dict, _ = PretrainedConfig.get_config_dict(
        model_id, revision=revision, trust_remote_code=trust_remote_code
    )
    model_type = config_dict.get("model_type", None)

    speculator = None
    if "medusa_num_heads" in config_dict:
        medusa_model_id = model_id
        medusa_revision = revision
        model_id = config_dict["base_model_name_or_path"]
        revision = "main"
        speculate_medusa = config_dict["medusa_num_heads"]
        if speculate is not None:
            if speculate > speculate_medusa:
                raise RuntimeError(
                    f"Speculate is set to `{speculate}` but this medusa model only has `{speculate_medusa}` heads; please make them match"
                )
            else:
                set_speculate(speculate)
        else:
            set_speculate(speculate_medusa)

        config_dict, _ = PretrainedConfig.get_config_dict(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        # Reload model type from parent.
        model_type = config_dict.get("model_type", None)
        is_local = Path(medusa_model_id).exists()
        if not is_local:
            medusa_config = hf_hub_download(
                medusa_model_id, revision=medusa_revision, filename="config.json"
            )
            hf_hub_download(
                medusa_model_id,
                revision=medusa_revision,
                filename="medusa_lm_head.safetensors",
            )
            speculator = {
                "path": Path(medusa_config).parent,
                "model_paths": ["medusa_lm_head.safetensors"],
            }
        else:
            speculator = {
                "path": Path(medusa_model_id),
                "model_paths": ["medusa_lm_head.safetensors"],
            }

        method = "medusa"
    elif model_type == "mlp_speculator":
        mlp_model_id = model_id
        mlp_revision = revision
        model_id = config_dict["base_model_name_or_path"]
        revision = "main"
        speculate_mlp = config_dict["n_predict"]
        if speculate is not None:
            if speculate > speculate_mlp:
                raise RuntimeError(
                    f"Speculate is set to `{speculate}` but this mlp_speculator model only has `{speculate_mlp}` heads; please make them match"
                )
            else:
                set_speculate(speculate)
        else:
            set_speculate(speculate_mlp)

        config_dict, _ = PretrainedConfig.get_config_dict(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        # Reload model type from parent.
        model_type = config_dict.get("model_type", None)
        is_local = Path(mlp_model_id).exists()
        extension = ".safetensors"
        if not is_local:
            mlp_speculator_config = hf_hub_download(
                mlp_model_id, revision=mlp_revision, filename="config.json"
            )
            api = HfApi()
            info = api.model_info(mlp_model_id, revision=mlp_revision)
            filenames = [
                s.rfilename
                for s in info.siblings
                if s.rfilename.endswith(extension)
                and len(s.rfilename.split("/")) == 1
                and "arguments" not in s.rfilename
                and "args" not in s.rfilename
                and "training" not in s.rfilename
            ]
            for filename in filenames:
                hf_hub_download(
                    mlp_model_id,
                    revision=mlp_revision,
                    filename=filename,
                )
            speculator_dir_path = Path(mlp_speculator_config).parent
            # if these are downloaded, they get converted to safetensors
            filenames.extend(
                [p for p in os.listdir(speculator_dir_path) if p.endswith(extension)]
            )
            speculator = {
                "path": Path(mlp_speculator_config).parent,
                "model_paths": filenames,
            }
        else:
            speculator = Path(mlp_model_id)
            filenames = [p for p in os.listdir(speculator) if p.endswith(extension)]
            speculator = {"path": speculator, "model_paths": filenames}
        method = "mlp_speculator"
    else:
        method = "n-gram"

    speculate = get_speculate()
    if speculate > 0:
        logger.info(f"Using speculation {method} with {speculate} input ids.")
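
    # Editorial note: when a medusa or mlp_speculator repo was detected above, the
    # resolved `speculator` is a dict of the form
    #   {"path": Path(<speculator dir>), "model_paths": ["<weights>.safetensors", ...]}
    # as built in the branches above; for the plain n-gram method it stays None.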
    model_type = config_dict["model_type"]

    if model_type == "gpt_bigcode":
        return StarCoder(model_id, revision, dtype)

    if model_type == "bloom":
        return BLOOM(
            model_id,
            revision,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    if model_type == "llava_next":
        return VlmCausalLM(
            model_class=LlavaNextForConditionalGeneration,
            model_id=model_id,
            revision=revision,
            quantize=None,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    if model_type == "mllama":
        return MllamaCausalLM(
            model_id=model_id,
            model_class=MllamaForConditionalGeneration,
            batch_class=MllamaCausalLMBatch,
            revision=revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            default_dtype=torch.bfloat16,
            trust_remote_code=trust_remote_code,
            lora_adapter_ids=lora_adapter_ids,
        )

    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
        return CausalLM(
            model_id,
            revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    raise ValueError(f"Unsupported model type {model_type}")
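

# Editorial sketch (not part of the original module): a minimal illustration of the
# speculator config.json fields that get_model reads above. Only the key names are
# taken from the code; the values and repo ids are hypothetical.
_EXAMPLE_SPECULATOR_CONFIGS = {
    "medusa": {
        "medusa_num_heads": 4,  # number of medusa heads, compared against the `speculate` argument
        "base_model_name_or_path": "org/base-model",  # hypothetical base model repo
    },
    "mlp_speculator": {
        "n_predict": 4,  # number of speculative tokens, compared against the `speculate` argument
        "base_model_name_or_path": "org/base-model",  # hypothetical base model repo
    },
}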


# get_model_with_lora_adapters wraps the internal get_model function and adds support for loading adapters
# this provides a post model loading hook to load adapters into the model after the model has been loaded
def get_model_with_lora_adapters(
    model_id: str,
    lora_adapters: Optional[List[AdapterInfo]],
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
    speculate: Optional[int],
    dtype: Optional[torch.dtype],
    trust_remote_code: bool,
    max_input_tokens: int,
    adapter_to_index: Dict[str, int],
):
    lora_adapter_ids = [adapter.id for adapter in lora_adapters]
    model = get_model(
        model_id,
        lora_adapter_ids,
        revision,
        sharded,
        quantize,
        speculate,
        dtype,
        trust_remote_code,
        max_input_tokens,
    )

    if len(lora_adapters) > 0:
        target_to_layer = build_layer_weight_lookup(model.model)
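
        # Editorial note: each adapter below gets a 1-based adapter_index (index 0 is
        # presumably reserved for the base model with no adapter), its weights are
        # fetched and merged via load_and_merge_adapters, and the resulting per-layer
        # LoraWeights are registered on the model.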
        for index, adapter in enumerate(lora_adapters):
            # The AdapterParameters object allows for merging multiple adapters into a single adapter.
            # At the moment, we only support loading a single adapter into the model, but we keep the
            # AdapterParameters object for easier extension in the future.
            adapter_parameters = AdapterParameters(
                adapter_info=[adapter],
                # when merging multiple adapters we can weight them differently
                # if this is not set, all adapters will be weighted equally
                # see: text_generation_server.utils.merges.strategies for impl
                weights=None,
                merge_strategy=0,
                density=1.0,
                majority_sign_method=0,
            )

            adapter_index = index + 1
            adapter_to_index[adapter.id] = adapter_index

            logger.info(
                f"Loading adapter weights into model: {','.join([adapter.id for adapter in adapter_parameters.adapter_info])}"
            )
            weight_names = tuple([v[0] for v in target_to_layer.values()])
            (
                module_map,
                adapter_config,
                adapter_weight_names,
                adapter_tokenizer,
            ) = load_and_merge_adapters(
                model.model_id,
                adapter_parameters,
                adapter_index,
                weight_names,
                False,
            )

            unused_weight_names = adapter_weight_names.copy()

            adapter_layers = [
                "q_proj",
                "k_proj",
                "v_proj",
                "o_proj",
                "gate_proj",
                "up_proj",
                "down_proj",
                "qkv_proj",
            ]

            for layer_name in adapter_layers:
                nlayers = (
                    1 if layer_name == "lm_head" else len(model.model.model.layers)
                )
                adapter_weights = LoraWeights.prepare_weights(
                    config=adapter_config,
                    module_map=module_map,
                    layer_type=layer_name,
                    unused_weight_names=unused_weight_names,
                    nlayers=nlayers,
                    dtype=model.dtype,
                    world_size=model.world_size,
                    process_group=model.process_group,
                    target_to_layer=target_to_layer,
                )

                if adapter_weights is None:
                    continue

                model.layer_to_adapter_weights[layer_name].add_adapter(
                    adapter_index, adapter_weights
                )

            if len(unused_weight_names) > 0:
                logger.warning(
                    f"{','.join([a.id for a in lora_adapters])} unused adapter weights: {unused_weight_names}"
                )

            if adapter_tokenizer is not None:
                model.tokenizers.add_tokenizer(adapter_index, adapter_tokenizer)

            model.loaded_adapters.add(adapter_index)

    return model
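

# ---------------------------------------------------------------------------
# Editorial usage sketch (not part of the original module). It shows how the
# entry points above fit together. The model id is hypothetical, and actually
# running this requires a Gaudi/Habana environment with the server's runtime
# dependencies installed.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    _adapter_to_index: Dict[str, int] = {}
    _model = get_model_with_lora_adapters(
        model_id="org/some-causal-lm",  # hypothetical model repo id
        lora_adapters=[],  # no adapters: the base model is returned unchanged
        revision=None,
        sharded=False,
        quantize=None,
        speculate=None,
        dtype=torch.bfloat16,
        trust_remote_code=False,
        max_input_tokens=1024,
        adapter_to_index=_adapter_to_index,
    )
    logger.info(f"Loaded model wrapper: {type(_model).__name__}")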