Improve Transformers support (#2970)

* Much better support

* add gpt neox

* bump transformers version

* bump version
Cyril Vallez 2025-02-18 19:04:34 +01:00 committed by GitHub
parent 5543fdc765
commit a7448661f7
7 changed files with 111 additions and 79 deletions

View File

@@ -1,3 +1,3 @@
-transformers==4.48.2
+transformers==4.49
 huggingface-hub==0.28.1
 hf-transfer==0.1.9

View File

@@ -346,7 +346,7 @@ tqdm==4.66.5
     #   outlines
     #   peft
     #   transformers
-transformers==4.48.2
+transformers==4.49
     # via
     #   text-generation-server (pyproject.toml)
     #   compressed-tensors

View File

@@ -158,7 +158,7 @@ tqdm==4.67.1
     # via
     #   huggingface-hub
     #   transformers
-transformers==4.48.2
+transformers==4.49
     # via text-generation-server (pyproject.toml)
 typer==0.15.1
     # via text-generation-server (pyproject.toml)

View File

@@ -331,7 +331,7 @@ tqdm==4.66.5
     #   outlines
     #   peft
     #   transformers
-transformers==4.48.2
+transformers==4.49
     # via
     #   text-generation-server (pyproject.toml)
     #   compressed-tensors

View File

@@ -331,7 +331,7 @@ tqdm==4.66.5
     #   outlines
     #   peft
     #   transformers
-transformers==4.48.2
+transformers==4.49
     # via
     #   text-generation-server (pyproject.toml)
     #   compressed-tensors

View File

@@ -6,17 +6,18 @@ from compressed_tensors.compressors.model_compressors.model_compressor import (
 )
 from compressed_tensors.quantization import QuantizationType
 from pydantic import ValidationError
-import torch
 import enum
 import os
-from loguru import logger
-from transformers.configuration_utils import PretrainedConfig
-from transformers.models.auto import modeling_auto
-from huggingface_hub import hf_hub_download, HfApi
 from typing import Optional, List, Dict
 from pathlib import Path
+from loguru import logger
+import torch
 import transformers
+from transformers.configuration_utils import PretrainedConfig
+from transformers.models.auto import modeling_auto
+from transformers.dynamic_module_utils import get_class_from_dynamic_module
+from huggingface_hub import hf_hub_download, HfApi
 from text_generation_server.utils.speculate import get_speculate, set_speculate
 from text_generation_server.models.model import Model
@@ -736,7 +737,7 @@ def get_model(
                 FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
             )
         else:
-            return transformers_causal_lm_class.fallback(
+            return CausalLM.fallback(
                 model_id=model_id,
                 revision=revision,
                 quantize=quantize,
@@ -857,6 +858,15 @@ def get_model(
                 lora_adapter_ids=lora_adapter_ids,
                 config_class=GPTNeoXConfig,
             )
+        elif FLASH_TRANSFORMERS_BACKEND:
+            return TransformersFlashCausalLM.fallback(
+                model_id,
+                revision,
+                quantize=quantize,
+                speculator=speculator,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
         elif sharded:
             return CausalLM(
                 model_id=model_id,
@@ -1054,6 +1064,15 @@ def get_model(
                 trust_remote_code=trust_remote_code,
                 lora_adapter_ids=lora_adapter_ids,
             )
+        elif FLASH_TRANSFORMERS_BACKEND:
+            return TransformersFlashCausalLM.fallback(
+                model_id,
+                revision,
+                quantize=quantize,
+                speculator=speculator,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma2"))
         else:
@@ -1467,17 +1486,26 @@ def get_model(
     elif quantize == "exl2":
         raise NotImplementedError("exl2 quantization is not supported for AutoModel")
 
-    # Fast transformers if available
-    transformers_model_class = getattr(
-        transformers,
-        modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.get(model_type, ""),
-        None,
-    )
-    if (
-        FLASH_TRANSFORMERS_BACKEND
-        and transformers_model_class is not None
-        and transformers_model_class._supports_flex_attn
-    ):
+    auto_map = config_dict.get("auto_map", None)
+    model_class = None
+
+    # If the model is already in the library
+    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
+        model_class = getattr(
+            transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type]
+        )
+    elif (
+        trust_remote_code
+        and auto_map is not None
+        and "AutoModelForCausalLM" in auto_map.keys()
+    ):
+        model_class = get_class_from_dynamic_module(
+            config_dict["auto_map"]["AutoModelForCausalLM"], model_id
+        )
+
+    # This means the model is ForCausalLM
+    if model_class is not None:
+        if FLASH_TRANSFORMERS_BACKEND and model_class.is_backend_compatible():
             return TransformersFlashCausalLM.fallback(
                 model_id,
                 revision,
@@ -1486,23 +1514,9 @@ def get_model(
                 dtype=dtype,
                 trust_remote_code=trust_remote_code,
             )
-
-    if sharded:
+        elif sharded:
             raise NotImplementedError("sharded is not supported for AutoModel")
-
-    if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
-        return Seq2SeqLM.fallback(
-            model_id,
-            revision,
-            quantize=quantize,
-            speculator=speculator,
-            dtype=dtype,
-            trust_remote_code=trust_remote_code,
-        )
-
-    auto_map = config_dict.get("auto_map", None)
-    if trust_remote_code and auto_map is not None:
-        if "AutoModelForCausalLM" in auto_map.keys():
+        else:
             return CausalLM.fallback(
                 model_id,
                 revision,
@@ -1511,7 +1525,17 @@ def get_model(
                 dtype=dtype,
                 trust_remote_code=trust_remote_code,
             )
-        if "AutoModelForSeq2SeqLM" in auto_map.keys():
+
+    # Not supported at this point
+    if sharded:
+        raise NotImplementedError("sharded is not supported for AutoModel")
+
+    # This means it is a ForSeq2SeqLM model
+    if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES or (
+        trust_remote_code
+        and auto_map is not None
+        and "AutoModelForSeq2SeqLM" in auto_map.keys()
+    ):
         return Seq2SeqLM.fallback(
             model_id,
             revision,
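The restructured block above first checks Transformers' built-in causal-LM registry and only then the `auto_map` entry of trust-remote-code models; anything that resolves to a class is treated as ForCausalLM, everything else may still be served as Seq2SeqLM. A condensed sketch of that resolution, assuming `model_type` and `config_dict` come from the model's config.json as elsewhere in get_model():

    import transformers
    from transformers.models.auto import modeling_auto
    from transformers.dynamic_module_utils import get_class_from_dynamic_module

    def resolve_causal_lm_class(model_type, config_dict, model_id, trust_remote_code):
        """Sketch of the new ForCausalLM class resolution in get_model()."""
        auto_map = config_dict.get("auto_map", None)
        # 1) Architecture shipped with Transformers itself.
        if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
            return getattr(
                transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type]
            )
        # 2) Remote-code model that advertises an AutoModelForCausalLM entry.
        if trust_remote_code and auto_map and "AutoModelForCausalLM" in auto_map:
            return get_class_from_dynamic_module(
                auto_map["AutoModelForCausalLM"], model_id
            )
        # 3) Not a causal LM; get_model() then considers the Seq2SeqLM path.
        return None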

View File

@@ -81,6 +81,15 @@ def tgi_flash_attention_forward(
 
 transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS["tgi"] = tgi_flash_attention_forward
 
+# The base TP plan of these models has replicated q/k/v. This means that each process will see the full states,
+# hence we should not divide the number of heads by the world size. This is a known waste of VRAM (the cache
+# will be fully replicated on each process) and GPU communication (additional all-gather operations), however due
+# to internal constraints it was not (yet?) possible to circumvent
+REPLICATED_ATTENTION_MODELS = [
+    "olmo2",
+    "phi3",
+]
+
 
 class TransformersFlashCausalLM(FlashCausalLM):
     def __init__(
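For context on the `ALL_ATTENTION_FUNCTIONS["tgi"] = tgi_flash_attention_forward` registration above: recent Transformers releases dispatch attention through this registry, keyed by the `attn_implementation` requested at load time, which is how TGI injects its own attention kernel. A hedged sketch of the mechanism with a toy SDPA-based function; the exact keyword arguments and expected output layout are assumptions about the Transformers attention interface, not something this diff guarantees:

    import torch
    import transformers

    def toy_attention_forward(module, query, key, value, attention_mask=None, **kwargs):
        # Assumption: q/k/v arrive as (batch, num_heads, seq_len, head_dim);
        # grouped-query models would additionally need their KV heads repeated here.
        out = torch.nn.functional.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask
        )
        # Assumed contract: return (attn_output, attn_weights), with attn_output
        # laid out as (batch, seq_len, num_heads, head_dim).
        return out.transpose(1, 2).contiguous(), None

    # Register under a custom name, then request it when loading:
    transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS["toy"] = toy_attention_forward
    # model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="toy")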
@@ -119,6 +128,7 @@ class TransformersFlashCausalLM(FlashCausalLM):
             truncation_side="left",
             trust_remote_code=trust_remote_code,
         )
+
         model = AutoModelForCausalLM.from_pretrained(
             model_id,
             revision=revision,
@@ -130,6 +140,8 @@ class TransformersFlashCausalLM(FlashCausalLM):
             tp_plan="auto" if world_size > 1 else None,
         )
 
+        torch.distributed.barrier(group=self.process_group)
+
         if tokenizer.pad_token_id is None:
             if model.config.pad_token_id is not None:
                 tokenizer.pad_token_id = model.config.pad_token_id
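The barrier added right after `from_pretrained` makes every rank wait until all processes have materialized their shard; with `tp_plan="auto"` Transformers itself splits the weights across the processes spawned by the launcher. A rough standalone sketch of that loading pattern, assuming the usual torch.distributed environment (e.g. launched with `torchrun --nproc-per-node=2`) and a placeholder model id:

    import torch
    import torch.distributed as dist
    from transformers import AutoModelForCausalLM

    dist.init_process_group("nccl")
    world_size = dist.get_world_size()

    model = AutoModelForCausalLM.from_pretrained(
        "my-org/my-model",                           # placeholder model id
        torch_dtype=torch.bfloat16,
        tp_plan="auto" if world_size > 1 else None,  # let Transformers tensor-parallelize the weights
    )

    # Every rank waits here until all shards are loaded before serving anything.
    dist.barrier()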
@@ -143,14 +155,18 @@ class TransformersFlashCausalLM(FlashCausalLM):
                 tokenizer.add_special_tokens({"pad_token": "[PAD]"})
 
         self.num_layers = model.config.num_hidden_layers
-        self.num_heads = model.config.num_attention_heads // self.process_group.size()
+        self.num_heads = model.config.num_attention_heads
         self.num_kv_heads = model.config.num_key_value_heads
+        self.head_size = model.config.hidden_size // model.config.num_attention_heads
+
+        # Skip it for models in the exception list
+        if model.config.model_type not in REPLICATED_ATTENTION_MODELS:
+            self.num_heads = self.num_heads // self.process_group.size()
             self.num_kv_heads = (
                 self.num_kv_heads // self.process_group.size()
                 if self.num_kv_heads > 1
                 else self.num_kv_heads
             )
-        self.head_size = model.config.hidden_size // model.config.num_attention_heads
 
         self.cuda_graphs = {}
         self.kv_cache = []
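The head bookkeeping above now divides query/KV heads by the world size only for models whose TP plan actually shards attention; models in REPLICATED_ATTENTION_MODELS keep the full head count (and a fully replicated KV cache) on every rank. A small worked sketch with made-up head counts:

    REPLICATED_ATTENTION_MODELS = ["olmo2", "phi3"]

    def per_rank_heads(num_heads, num_kv_heads, model_type, world_size):
        """Per-process head counts, mirroring TransformersFlashCausalLM.__init__."""
        if model_type not in REPLICATED_ATTENTION_MODELS:
            num_heads = num_heads // world_size
            # A single shared KV head (MQA) stays replicated on every rank.
            num_kv_heads = num_kv_heads // world_size if num_kv_heads > 1 else num_kv_heads
        return num_heads, num_kv_heads

    print(per_rank_heads(32, 8, "llama", world_size=4))  # (8, 2): attention is sharded
    print(per_rank_heads(32, 32, "phi3", world_size=4))  # (32, 32): q/k/v stay replicated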
@@ -186,7 +202,6 @@ class TransformersFlashCausalLM(FlashCausalLM):
             torch.tensor(1.0, device=device),
         )
 
-        torch.distributed.barrier(group=self.process_group)
         # Skip FlashCausalLM init.
         super(FlashCausalLM, self).__init__(
             model_id=model_id,
@@ -204,6 +219,8 @@ class TransformersFlashCausalLM(FlashCausalLM):
         self.model.original_forward = self.model.forward
         self.model.forward = self._model_forward
 
+        torch.distributed.barrier(group=self.process_group)
+
     @classmethod
     def fallback(
         cls,
@@ -237,11 +254,16 @@ class TransformersFlashCausalLM(FlashCausalLM):
         prefill_cache_indices=None,  # not used, but passed to match original signature
         adapter_data=None,  # not supported, but passed to match original signature
     ):
-        hidden_states = self.model.model.forward(
+        # A value of `None` (i.e. no logit slicing) translates to `0` in Transformers
+        logits_to_keep = lm_head_indices if lm_head_indices is not None else 0
+
+        # This is equivalent to `self.model.forward`, see the monkey patch in __init__
+        logits = self.model.original_forward(
             input_ids=input_ids.unsqueeze(0),  # expand dim to fit Transformers
             position_ids=position_ids.unsqueeze(0),  # expand dim to fit Transformers
             past_key_values=None,  # we use self.kv_cache instead of transformers cache object
             use_cache=False,  # we use self.kv_cache instead of transformers cache object
+            logits_to_keep=logits_to_keep,
             return_dict=True,
             cu_seqlen_prefill=cu_seqlen_prefill,
             kv_cache=kv_cache,
@@ -251,20 +273,6 @@ class TransformersFlashCausalLM(FlashCausalLM):
             max_s=max_s,
             kv_head_mapping=self.kv_head_mapping,
             kv_scales=self.kv_scales,
-        )[0].squeeze(dim=0)
-
-        # And compute logits from the lm_head, slicing correctly the indices
-        # NOTE: some logits post-processing (e.g. in gemma2) may be absent here with the split of the modules
-        # To update with full Transformers support asap
-        if lm_head_indices is not None:
-            hidden_states = hidden_states[lm_head_indices]
-        logits = self.model.lm_head(hidden_states)
-
-        # For Granite while next transformers version is released and we can use `lm_head_indices` natively
-        if hasattr(self.model.config, "logits_scaling"):
-            logits = logits / self.model.config.logits_scaling
-        # For Cohere for similar reasons
-        elif hasattr(self.model, "logit_scale"):
-            logits = logits * self.model.logit_scale
+        ).logits.squeeze(dim=0)
 
         return logits, None
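With the forward call rewritten to go through `original_forward`, logit slicing and model-specific post-processing (e.g. Granite's `logits_scaling`, Cohere's `logit_scale`) are handled inside the Transformers ForCausalLM head rather than re-implemented here. The only translation left on the TGI side is mapping `lm_head_indices` onto `logits_to_keep`, which as of transformers 4.49 accepts an int (0 meaning keep logits for every position) or a 1-D tensor of positions, the latter being what prefill uses to keep only each sequence's last token. A tiny sketch of that mapping:

    import torch

    def to_logits_to_keep(lm_head_indices):
        """Map TGI's lm_head_indices onto Transformers' `logits_to_keep` argument."""
        # None means "no slicing" (e.g. a decode step where every row is a last token),
        # which Transformers expresses as 0; a tensor keeps only the listed positions.
        return lm_head_indices if lm_head_indices is not None else 0

    print(to_logits_to_keep(None))                  # 0 -> logits for all positions
    print(to_logits_to_keep(torch.tensor([3, 7])))  # keep only rows 3 and 7 (prefill last tokens)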