Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-19 22:02:06 +00:00)

Commit a7448661f7 (parent 5543fdc765)

Improve Transformers support (#2970)

* Much better support
* add gpt neox
* bump transformers version
* bump version

@@ -1,3 +1,3 @@
-transformers==4.48.2
+transformers==4.49
 huggingface-hub==0.28.1
 hf-transfer==0.1.9

@@ -346,7 +346,7 @@ tqdm==4.66.5
     # outlines
     # peft
     # transformers
-transformers==4.48.2
+transformers==4.49
     # via
     # text-generation-server (pyproject.toml)
     # compressed-tensors

@@ -158,7 +158,7 @@ tqdm==4.67.1
     # via
     # huggingface-hub
     # transformers
-transformers==4.48.2
+transformers==4.49
     # via text-generation-server (pyproject.toml)
 typer==0.15.1
     # via text-generation-server (pyproject.toml)

@@ -331,7 +331,7 @@ tqdm==4.66.5
     # outlines
     # peft
     # transformers
-transformers==4.48.2
+transformers==4.49
     # via
     # text-generation-server (pyproject.toml)
     # compressed-tensors

@@ -331,7 +331,7 @@ tqdm==4.66.5
     # outlines
     # peft
     # transformers
-transformers==4.48.2
+transformers==4.49
     # via
     # text-generation-server (pyproject.toml)
     # compressed-tensors

@@ -6,17 +6,18 @@ from compressed_tensors.compressors.model_compressors.model_compressor import (
 )
 from compressed_tensors.quantization import QuantizationType
 from pydantic import ValidationError
-import torch
 import enum
 import os
 
-from loguru import logger
-from transformers.configuration_utils import PretrainedConfig
-from transformers.models.auto import modeling_auto
-from huggingface_hub import hf_hub_download, HfApi
 from typing import Optional, List, Dict
 from pathlib import Path
+from loguru import logger
 
+import torch
 import transformers
+from transformers.configuration_utils import PretrainedConfig
+from transformers.models.auto import modeling_auto
+from transformers.dynamic_module_utils import get_class_from_dynamic_module
+from huggingface_hub import hf_hub_download, HfApi
 
 from text_generation_server.utils.speculate import get_speculate, set_speculate
 from text_generation_server.models.model import Model

@@ -736,7 +737,7 @@ def get_model(
                 FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
             )
         else:
-            return transformers_causal_lm_class.fallback(
+            return CausalLM.fallback(
                 model_id=model_id,
                 revision=revision,
                 quantize=quantize,

@@ -857,6 +858,15 @@ def get_model(
                 lora_adapter_ids=lora_adapter_ids,
                 config_class=GPTNeoXConfig,
             )
+        elif FLASH_TRANSFORMERS_BACKEND:
+            return TransformersFlashCausalLM.fallback(
+                model_id,
+                revision,
+                quantize=quantize,
+                speculator=speculator,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
         elif sharded:
             return CausalLM(
                 model_id=model_id,

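Note: the same fallback chain is inserted for each architecture this commit touches (GPT-NeoX here, Gemma2 below): the existing flash path stays first, the new Transformers flash backend comes second, and the plain `CausalLM` paths remain last. A rough sketch of that ordering follows; it is illustrative only, with the flags and returned names standing in for the real objects used inside `get_model`.

# Illustrative dispatch order only; the flags and returned names are stand-ins for the
# real objects in text_generation_server/models/__init__.py, not the actual TGI code.
def pick_backend(flash_attention: bool, flash_transformers_backend: bool, sharded: bool) -> str:
    if flash_attention:
        return "FlashCausalLM subclass (custom TGI kernels)"
    elif flash_transformers_backend:
        return "TransformersFlashCausalLM.fallback"  # new in this commit
    elif sharded:
        return "sharded CausalLM (or NotImplementedError for some models)"
    else:
        return "CausalLM.fallback"
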
@@ -1054,6 +1064,15 @@ def get_model(
                 trust_remote_code=trust_remote_code,
                 lora_adapter_ids=lora_adapter_ids,
             )
+        elif FLASH_TRANSFORMERS_BACKEND:
+            return TransformersFlashCausalLM.fallback(
+                model_id,
+                revision,
+                quantize=quantize,
+                speculator=speculator,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma2"))
         else:

@@ -1467,17 +1486,26 @@ def get_model(
     elif quantize == "exl2":
         raise NotImplementedError("exl2 quantization is not supported for AutoModel")
 
-    # Fast transformers if available
-    transformers_model_class = getattr(
-        transformers,
-        modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.get(model_type, ""),
-        None,
-    )
-    if (
-        FLASH_TRANSFORMERS_BACKEND
-        and transformers_model_class is not None
-        and transformers_model_class._supports_flex_attn
-    ):
-        return TransformersFlashCausalLM.fallback(
-            model_id,
-            revision,
+    auto_map = config_dict.get("auto_map", None)
+    model_class = None
+
+    # If the model is already in the library
+    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
+        model_class = getattr(
+            transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type]
+        )
+    elif (
+        trust_remote_code
+        and auto_map is not None
+        and "AutoModelForCausalLM" in auto_map.keys()
+    ):
+        model_class = get_class_from_dynamic_module(
+            config_dict["auto_map"]["AutoModelForCausalLM"], model_id
+        )
+
+    # This means the model is ForCausalLM
+    if model_class is not None:
+        if FLASH_TRANSFORMERS_BACKEND and model_class.is_backend_compatible():
+            return TransformersFlashCausalLM.fallback(
+                model_id,
+                revision,

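The block above replaces the old `_supports_flex_attn` probe with an explicit two-step lookup: architectures shipped with Transformers are resolved through `MODEL_FOR_CAUSAL_LM_MAPPING_NAMES`, and Hub models with custom code are resolved through their `auto_map` when `trust_remote_code` is set. A condensed sketch of that lookup is below; the helper name is hypothetical and its arguments mirror what `get_model` already has in scope.

import transformers
from transformers.models.auto import modeling_auto
from transformers.dynamic_module_utils import get_class_from_dynamic_module


def resolve_causal_lm_class(model_type, config_dict, model_id, trust_remote_code):
    """Hypothetical helper mirroring the lookup above: library mapping first, then remote code."""
    auto_map = config_dict.get("auto_map", None)

    # 1. Architecture is part of the transformers library
    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
        return getattr(
            transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type]
        )

    # 2. Custom modeling code on the Hub, only when the user opted in
    if trust_remote_code and auto_map is not None and "AutoModelForCausalLM" in auto_map:
        return get_class_from_dynamic_module(auto_map["AutoModelForCausalLM"], model_id)

    return None  # no ForCausalLM class found; caller falls through to the Seq2Seq handling
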
@@ -1486,23 +1514,9 @@ def get_model(
                 dtype=dtype,
                 trust_remote_code=trust_remote_code,
             )
-
-    if sharded:
+        elif sharded:
             raise NotImplementedError("sharded is not supported for AutoModel")
-    if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
-        return Seq2SeqLM.fallback(
-            model_id,
-            revision,
-            quantize=quantize,
-            speculator=speculator,
-            dtype=dtype,
-            trust_remote_code=trust_remote_code,
-        )
-
-    auto_map = config_dict.get("auto_map", None)
-    if trust_remote_code and auto_map is not None:
-        if "AutoModelForCausalLM" in auto_map.keys():
+        else:
             return CausalLM.fallback(
                 model_id,
                 revision,

@@ -1511,7 +1525,17 @@ def get_model(
                 dtype=dtype,
                 trust_remote_code=trust_remote_code,
             )
-        if "AutoModelForSeq2SeqLM" in auto_map.keys():
+
+    # Not supported at this point
+    if sharded:
+        raise NotImplementedError("sharded is not supported for AutoModel")
+
+    # This means it is a ForSeq2SeqLM model
+    if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES or (
+        trust_remote_code
+        and auto_map is not None
+        and "AutoModelForSeq2SeqLM" in auto_map.keys()
+    ):
         return Seq2SeqLM.fallback(
             model_id,
             revision,

@@ -81,6 +81,15 @@ def tgi_flash_attention_forward(
 
 transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS["tgi"] = tgi_flash_attention_forward
 
+# The base TP plan of these models has replicated q/k/v. This means that each process will see the full states,
+# hence we should not divide the number of heads by the world size. This is a known waste of VRAM (the cache
+# will be fully replicated on each process) and GPU communication (additional all-gather operations), however due
+# to internal constraints it was not (yet?) possible to circumvent
+REPLICATED_ATTENTION_MODELS = [
+    "olmo2",
+    "phi3",
+]
+
 
 class TransformersFlashCausalLM(FlashCausalLM):
     def __init__(

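To make the VRAM cost described in the comment concrete: with a replicated TP plan every rank keeps the full KV cache instead of roughly 1/world_size of it. A back-of-the-envelope helper, under assumed shapes and fp16 cache entries (illustrative only, not TGI code):

def kv_cache_bytes(num_layers, num_kv_heads, head_size, num_tokens,
                   world_size, replicated, dtype_bytes=2):
    # 2x for the key and value tensors; heads are only sharded when not replicated
    heads_per_rank = num_kv_heads if replicated else max(num_kv_heads // world_size, 1)
    return 2 * num_layers * heads_per_rank * head_size * num_tokens * dtype_bytes


# Phi-3-mini-like shape (32 layers, 32 KV heads, head_size 96), 8192 cached tokens, 4 GPUs
print(kv_cache_bytes(32, 32, 96, 8192, 4, replicated=True) / 2**30)   # 3.0 GiB per rank
print(kv_cache_bytes(32, 32, 96, 8192, 4, replicated=False) / 2**30)  # 0.75 GiB per rank
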
@@ -119,6 +128,7 @@ class TransformersFlashCausalLM(FlashCausalLM):
             truncation_side="left",
             trust_remote_code=trust_remote_code,
         )
+
         model = AutoModelForCausalLM.from_pretrained(
             model_id,
             revision=revision,

@@ -130,6 +140,8 @@ class TransformersFlashCausalLM(FlashCausalLM):
             tp_plan="auto" if world_size > 1 else None,
         )
 
+        torch.distributed.barrier(group=self.process_group)
+
         if tokenizer.pad_token_id is None:
             if model.config.pad_token_id is not None:
                 tokenizer.pad_token_id = model.config.pad_token_id

@@ -143,14 +155,18 @@ class TransformersFlashCausalLM(FlashCausalLM):
             tokenizer.add_special_tokens({"pad_token": "[PAD]"})
 
         self.num_layers = model.config.num_hidden_layers
-        self.num_heads = model.config.num_attention_heads // self.process_group.size()
+        self.num_heads = model.config.num_attention_heads
         self.num_kv_heads = model.config.num_key_value_heads
+        self.head_size = model.config.hidden_size // model.config.num_attention_heads
+
+        # Skip it for models in the exception list
+        if model.config.model_type not in REPLICATED_ATTENTION_MODELS:
+            self.num_heads = self.num_heads // self.process_group.size()
             self.num_kv_heads = (
                 self.num_kv_heads // self.process_group.size()
                 if self.num_kv_heads > 1
                 else self.num_kv_heads
             )
-        self.head_size = model.config.hidden_size // model.config.num_attention_heads
 
         self.cuda_graphs = {}
         self.kv_cache = []

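A quick illustration of the per-rank numbers this bookkeeping produces, using hypothetical configs and world_size = 4 (the tuple is (num_heads, num_kv_heads) after the division above):

REPLICATED_ATTENTION_MODELS = ["olmo2", "phi3"]  # same exception list as above


def per_rank_heads(num_heads, num_kv_heads, model_type, world_size=4):
    if model_type not in REPLICATED_ATTENTION_MODELS:
        num_heads = num_heads // world_size
        # A single KV head (multi-query attention) cannot be split further, hence the > 1 guard
        num_kv_heads = num_kv_heads // world_size if num_kv_heads > 1 else num_kv_heads
    return num_heads, num_kv_heads


print(per_rank_heads(32, 8, "llama"))  # (8, 2): attention is sharded across ranks
print(per_rank_heads(32, 32, "phi3"))  # (32, 32): replicated q/k/v, heads kept whole
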
@@ -186,7 +202,6 @@ class TransformersFlashCausalLM(FlashCausalLM):
             torch.tensor(1.0, device=device),
         )
 
-        torch.distributed.barrier(group=self.process_group)
         # Skip FlashCausalLM init.
         super(FlashCausalLM, self).__init__(
             model_id=model_id,

@@ -204,6 +219,8 @@ class TransformersFlashCausalLM(FlashCausalLM):
         self.model.original_forward = self.model.forward
         self.model.forward = self._model_forward
 
+        torch.distributed.barrier(group=self.process_group)
+
     @classmethod
     def fallback(
         cls,

@@ -237,11 +254,16 @@ class TransformersFlashCausalLM(FlashCausalLM):
         prefill_cache_indices=None,  # not used, but passed to match original signature
         adapter_data=None,  # not supported, but passed to match original signature
     ):
-        hidden_states = self.model.model.forward(
+        # A value of `None` (i.e. no logit slicing) translates to `0` in Transformers
+        logits_to_keep = lm_head_indices if lm_head_indices is not None else 0
+
+        # This is equivalent to `self.model.forward`, see the monkey patch in __init__
+        logits = self.model.original_forward(
             input_ids=input_ids.unsqueeze(0),  # expand dim to fit Transformers
             position_ids=position_ids.unsqueeze(0),  # expand dim to fit Transformers
             past_key_values=None,  # we use self.kv_cache instead of transformers cache object
             use_cache=False,  # we use self.kv_cache instead of transformers cache object
+            logits_to_keep=logits_to_keep,
             return_dict=True,
             cu_seqlen_prefill=cu_seqlen_prefill,
             kv_cache=kv_cache,

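`logits_to_keep` replaces the manual `lm_head_indices` slicing that previously happened after the model forward (removed in the next hunk). The sketch below paraphrases the Transformers-side behaviour this relies on: an `int` keeps the last N positions (with `0` meaning all of them), while an index tensor keeps exactly those rows, which is how TGI's prefill `lm_head_indices` maps onto it.

import torch


def slice_before_lm_head(hidden_states, logits_to_keep):
    # Paraphrase of how `logits_to_keep` selects positions before the lm_head is applied
    if isinstance(logits_to_keep, int):
        indices = slice(-logits_to_keep, None) if logits_to_keep > 0 else slice(None)
    else:
        indices = logits_to_keep  # e.g. TGI's lm_head_indices at prefill time
    return hidden_states[:, indices, :]


h = torch.randn(1, 10, 8)
print(slice_before_lm_head(h, 0).shape)                     # torch.Size([1, 10, 8])
print(slice_before_lm_head(h, 2).shape)                     # torch.Size([1, 2, 8])
print(slice_before_lm_head(h, torch.tensor([3, 9])).shape)  # torch.Size([1, 2, 8])
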
@@ -251,20 +273,6 @@ class TransformersFlashCausalLM(FlashCausalLM):
             max_s=max_s,
             kv_head_mapping=self.kv_head_mapping,
             kv_scales=self.kv_scales,
-        )[0].squeeze(dim=0)
-
-        # And compute logits from the lm_head, slicing correctly the indices
-        # NOTE: some logits post-processing (e.g. in gemma2) may be absent here with the split of the modules
-        # To update with full Transformers support asap
-        if lm_head_indices is not None:
-            hidden_states = hidden_states[lm_head_indices]
-        logits = self.model.lm_head(hidden_states)
-
-        # For Granite while next transformers version is released and we can use `lm_head_indices` natively
-        if hasattr(self.model.config, "logits_scaling"):
-            logits = logits / self.model.config.logits_scaling
-        # For Cohere for similar reasons
-        elif hasattr(self.model, "logit_scale"):
-            logits = logits * self.model.logit_scale
+        ).logits.squeeze(dim=0)
 
         return logits, None