Flash Transformers modeling backend support (#2913)

* add transformers_flash

* inits

* switch version to make it work

* Update Makefile-flash-att-v2

* Update Makefile-flash-att-v2

* Update Makefile-flash-att-v2

* Update Makefile-flash-att-v2

* Update Makefile-flash-att-v2

* Update Makefile-flash-att-v2

* runnable version

* working

* push change

* fix high dim

* init

* default

* latest transformers changes

* revert

* simplify check

* remove flag

* improve type hints + required args

* Update based on transformers PR

* small fix

* Remove Warpers for Processor

* fix compatibility version issue

* raise error if needed

* Simplify with monkey patch

* revert + style + minor improvements

* update comment

* device check

* move the import to avoid device issue

* Update __init__.py

* check for non-native models

* oops

---------

Co-authored-by: System administrator <root@ip-10-90-0-159.ec2.internal>
Cyril Vallez 2025-01-21 10:01:51 +01:00 committed by GitHub
parent 447a5b2f87
commit b980848abf
4 changed files with 330 additions and 31 deletions

View File

@@ -956,15 +956,24 @@ def quantize(
pack(model, quantizers, bits, groupsize)
from safetensors.torch import save_file
from transformers.modeling_utils import shard_checkpoint
from huggingface_hub import split_torch_state_dict_into_shards
state_dict = model.state_dict()
state_dict = {k: v.cpu().contiguous() for k, v in state_dict.items()}
max_shard_size = "10GB"
shards, index = shard_checkpoint(
state_dict, max_shard_size=max_shard_size, weights_name="model.safetensors"
state_dict_split = split_torch_state_dict_into_shards(
state_dict,
filename_pattern="model.safetensors",
max_shard_size=max_shard_size,
)
index = None
if state_dict_split.is_sharded:
index = {
"metadata": state_dict_split.metadata,
"weight_map": state_dict_split.tensor_to_filename,
}
shards = state_dict_split.filename_to_tensors
os.makedirs(output_dir, exist_ok=True)
for shard_file, shard in shards.items():
save_file(
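For context: `transformers.modeling_utils.shard_checkpoint` was deprecated upstream, and this hunk migrates to its `huggingface_hub` replacement. A minimal, self-contained sketch of the new flow, assuming a toy model (the output path and filename pattern are illustrative):

```python
import os

import torch
from huggingface_hub import split_torch_state_dict_into_shards
from safetensors.torch import save_file

model = torch.nn.Linear(8, 8)
state_dict = {k: v.cpu().contiguous() for k, v in model.state_dict().items()}

# "{suffix}" is filled with "-00001-of-00002"-style suffixes when sharding occurs.
state_dict_split = split_torch_state_dict_into_shards(
    state_dict,
    filename_pattern="model{suffix}.safetensors",
    max_shard_size="10GB",
)

output_dir = "./out"
os.makedirs(output_dir, exist_ok=True)
# filename_to_tensors maps each shard file to the tensor names it should hold.
for shard_file, tensor_names in state_dict_split.filename_to_tensors.items():
    shard = {name: state_dict[name] for name in tensor_names}
    save_file(shard, os.path.join(output_dir, shard_file))

# Sharded checkpoints also need an index mapping tensor names to shard files.
if state_dict_split.is_sharded:
    index = {
        "metadata": state_dict_split.metadata,
        "weight_map": state_dict_split.tensor_to_filename,
    }
```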

View File

@@ -16,10 +16,12 @@ from transformers.models.auto import modeling_auto
from huggingface_hub import hf_hub_download, HfApi
from typing import Optional, List, Dict
from pathlib import Path
import transformers
from text_generation_server.utils.speculate import get_speculate, set_speculate
from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM, CausalLMBatchKeysLast
from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM
from text_generation_server.models.custom_modeling.mpt_modeling import (
MPTForCausalLM,
@@ -178,6 +180,14 @@ except ImportError as e:
if MAMBA_AVAILABLE:
__all__.append(Mamba)
FLASH_TRANSFORMERS_BACKEND = True
try:
from text_generation_server.models.transformers_flash_causal_lm import (
TransformersFlashCausalLM,
)
except ImportError:
FLASH_TRANSFORMERS_BACKEND = False
class ModelType(enum.Enum):
DEEPSEEK_V2 = {
@@ -381,6 +391,21 @@ def get_model(
)
model_type = config_dict.get("model_type", None)
transformers_causal_lm_class = CausalLM
# Fast transformers path
transformers_model_class = getattr(
transformers,
modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.get(model_type, ""),
None,
)
if (
FLASH_TRANSFORMERS_BACKEND
and transformers_model_class is not None
and transformers_model_class._supports_flex_attn
):
transformers_causal_lm_class = TransformersFlashCausalLM
quantization_config = config_dict.get("quantization_config", None)
if quantization_config is None:
quantization_config = config_dict.get("compression_config", None)
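The net effect of the block above: when the flash backend imported cleanly and the resolved `transformers` class for this architecture advertises flex-attention support, every `transformers_causal_lm_class.fallback(...)` call further down resolves to `TransformersFlashCausalLM` instead of the padded `CausalLM`. A condensed sketch of the decision (the helper name and the defensive `getattr` default are illustrative):

```python
import transformers
from transformers.models.auto import modeling_auto

def resolve_fallback_class(model_type: str, flash_backend_available: bool) -> str:
    # MODEL_FOR_CAUSAL_LM_MAPPING_NAMES maps e.g. "llama" -> "LlamaForCausalLM".
    class_name = modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.get(model_type, "")
    model_class = getattr(transformers, class_name, None)
    if (
        flash_backend_available
        and model_class is not None
        and getattr(model_class, "_supports_flex_attn", False)
    ):
        return "TransformersFlashCausalLM"  # flash-capable Transformers backend
    return "CausalLM"  # padded fallback path

print(resolve_fallback_class("llama", True))  # TransformersFlashCausalLM
print(resolve_fallback_class("bloom", True))  # CausalLM (no flex-attention support)
```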
@@ -624,7 +649,7 @@ def get_model(
FLASH_ATT_ERROR_MESSAGE.format("Sharded Deepseek V2")
)
else:
return CausalLM.fallback(
return transformers_causal_lm_class.fallback(
model_id,
revision,
quantize=quantize,
@@ -683,7 +708,7 @@ def get_model(
FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
)
else:
return CausalLM.fallback(
return transformers_causal_lm_class.fallback(
model_id=model_id,
revision=revision,
quantize=quantize,
@@ -731,7 +756,7 @@ def get_model(
except RuntimeError as e:
# Lots of legacy models with various weight names.
log_master(logger.warning, f"Couldn't load flash gpt2 variant: {e}")
return CausalLM.fallback(
return transformers_causal_lm_class.fallback(
model_id,
revision,
quantize=quantize,
@@ -742,7 +767,7 @@ def get_model(
elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-2"))
else:
return CausalLM.fallback(
return transformers_causal_lm_class.fallback(
model_id,
revision,
quantize=quantize,
@@ -767,7 +792,7 @@ def get_model(
except RuntimeError as e:
# Lots of legacy models with various weight names.
log_master(logger.warning, f"Couldn't load flash gptj variant: {e}")
return CausalLM.fallback(
return transformers_causal_lm_class.fallback(
model_id,
revision,
quantize=quantize,
@@ -778,7 +803,7 @@ def get_model(
elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-J"))
else:
return CausalLM.fallback(
return transformers_causal_lm_class.fallback(
model_id,
revision,
quantize=quantize,
@@ -815,7 +840,7 @@ def get_model(
trust_remote_code=trust_remote_code,
)
else:
return CausalLM.fallback(
return transformers_causal_lm_class.fallback(
model_id,
revision,
quantize=quantize,
@@ -838,7 +863,7 @@ def get_model(
lora_adapter_ids=lora_adapter_ids,
)
else:
return CausalLM.fallback(
return transformers_causal_lm_class.fallback(
model_id,
revision,
quantize=quantize,
@@ -862,7 +887,7 @@ def get_model(
lora_adapter_ids=lora_adapter_ids,
)
else:
return CausalLM.fallback(
return transformers_causal_lm_class.fallback(
model_id,
revision,
quantize=quantize,
@@ -911,7 +936,7 @@ def get_model(
FLASH_ATT_ERROR_MESSAGE.format(f"Sharded {model_type}")
)
else:
return CausalLM.fallback(
return transformers_causal_lm_class.fallback(
model_id,
revision,
quantize=quantize,
@@ -937,7 +962,7 @@ def get_model(
elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma"))
else:
return CausalLM.fallback(
return transformers_causal_lm_class.fallback(
model_id,
revision,
quantize=quantize,
@@ -963,7 +988,7 @@ def get_model(
elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma2"))
else:
return CausalLM.fallback(
return transformers_causal_lm_class.fallback(
model_id,
revision,
quantize=quantize,
@@ -988,7 +1013,7 @@ def get_model(
elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Cohere"))
else:
return CausalLM.fallback(
return transformers_causal_lm_class.fallback(
model_id,
revision,
quantize=quantize,
@@ -1016,7 +1041,7 @@ def get_model(
elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded DBRX"))
else:
return CausalLM.fallback(
return transformers_causal_lm_class.fallback(
model_id,
revision,
quantize=quantize,
@@ -1066,7 +1091,7 @@ def get_model(
config_class=RWConfig,
)
else:
return CausalLM.fallback(
return transformers_causal_lm_class.fallback(
model_id,
revision,
quantize=quantize,
@@ -1091,7 +1116,7 @@ def get_model(
elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mistral"))
else:
return CausalLM.fallback(
return transformers_causal_lm_class.fallback(
model_id,
revision,
quantize=quantize,
@@ -1116,7 +1141,7 @@ def get_model(
elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mixtral"))
else:
return CausalLM.fallback(
return transformers_causal_lm_class.fallback(
model_id,
revision,
quantize=quantize,
@@ -1143,7 +1168,7 @@ def get_model(
FLASH_ATT_ERROR_MESSAGE.format("Sharded Starcoder2")
)
else:
return CausalLM.fallback(
return transformers_causal_lm_class.fallback(
model_id,
revision,
quantize=quantize,
@@ -1168,7 +1193,7 @@ def get_model(
elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Qwen2"))
else:
return CausalLM.fallback(
return transformers_causal_lm_class.fallback(
model_id,
revision,
quantize=quantize,
@@ -1329,7 +1354,7 @@ def get_model(
elif quantize == "exl2":
raise NotImplementedError("exl2 quantization is not supported for AutoModel")
if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
return CausalLM.fallback(
return transformers_causal_lm_class.fallback(
model_id,
revision,
quantize=quantize,
@@ -1350,7 +1375,7 @@ def get_model(
auto_map = config_dict.get("auto_map", None)
if trust_remote_code and auto_map is not None:
if "AutoModelForCausalLM" in auto_map.keys():
return CausalLM.fallback(
return transformers_causal_lm_class.fallback(
model_id,
revision,
quantize=quantize,

View File

@@ -0,0 +1,266 @@
import math
from typing import List, Optional
import torch
from opentelemetry import trace
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
import transformers.modeling_utils
from text_generation_server.models.flash_causal_lm import FlashCausalLM
from text_generation_server.utils import initialize_torch_distributed
from text_generation_server.layers.attention import paged_attention, attention, Seqlen
from text_generation_server.layers.attention.kv_cache import KVScales, KVCache
from text_generation_server.models.globals import ATTENTION
tracer = trace.get_tracer(__name__)
def tgi_flash_attention_forward(
module,
query_states: torch.Tensor,
key_states: torch.Tensor,
value_states: torch.Tensor,
attention_mask: Optional[torch.Tensor], # This is a positional arg in Transformers
kv_cache: List[KVCache],
kv_head_mapping: torch.Tensor,
slots: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
seqlen: Seqlen,
block_tables: torch.Tensor,
max_s: int,
kv_scales: KVScales,
softmax_scale: Optional[float] = None,
sliding_window: Optional[int] = None,
softcap: Optional[float] = None,
**kwargs, # This is needed to "absorb" other args passed by Transformers modeling
):
kv_cache = kv_cache[module.layer_idx]
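# Transformers attention modules pass q/k/v as [batch=1, num_heads, seq_len, head_dim];
# TGI's kernels expect [seq_len, num_heads, head_dim], hence the transpose + squeeze below.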
query_states = query_states.transpose(1, 2).squeeze(dim=0)
key_states = key_states.transpose(1, 2).squeeze(dim=0)
value_states = value_states.transpose(1, 2).squeeze(dim=0)
# Take care of updating the cache in-place
kv_cache.store(key=key_states, value=value_states, slots=slots, kv_scales=kv_scales)
_, num_heads, head_dim = query_states.shape
softmax_scale = 1 / math.sqrt(head_dim) if softmax_scale is None else softmax_scale
sliding_window = -1 if sliding_window is None else sliding_window
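# Prefill (cu_seqlen_prefill is set) uses the variable-length flash attention kernel;
# decode uses paged attention over the block-structured KV cache.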
if cu_seqlen_prefill is not None:
attn_output = attention(
query=query_states,
key=key_states,
value=value_states,
kv_cache=kv_cache,
kv_scales=kv_scales,
seqlen=seqlen,
block_tables=block_tables,
softmax_scale=softmax_scale,
window_size_left=sliding_window,
softcap=softcap,
)
else:
attn_output = paged_attention(
query_states,
kv_cache,
kv_head_mapping,
softmax_scale,
block_tables,
seqlen,
max_s,
kv_scales=kv_scales,
softcap=softcap,
)
attn_output = attn_output.view(-1, num_heads * head_dim)
return attn_output, None
transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS["tgi"] = tgi_flash_attention_forward
class TransformersFlashCausalLM(FlashCausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
default_dtype=torch.float16,
trust_remote_code: bool = False,
tokenizer_class=AutoTokenizer,
config_class=AutoConfig,
kv_cache_dtype: Optional[torch.dtype] = None,
):
self.quantize = quantize
self.process_group, rank, world_size = initialize_torch_distributed()
if speculator:
raise RuntimeError("Speculator decoding is not enabled for AutoModel")
if torch.cuda.is_available():
device = torch.device("cuda:0")
dtype = torch.float16 if dtype is None else dtype
elif hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device("xpu")
dtype = torch.float16 if dtype is None else dtype
else:
raise ValueError(
"Flash `Transformers` modeling backend is not available on cpu."
)
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
model = AutoModelForCausalLM.from_pretrained(
model_id,
revision=revision,
torch_dtype=dtype,
device_map="auto",
load_in_8bit=quantize == "bitsandbytes",
trust_remote_code=trust_remote_code,
attn_implementation="tgi",
tp_plan="auto" if world_size > 1 else None,
)
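# attn_implementation="tgi" makes every attention layer dispatch to the function
# registered under that key in ALL_ATTENTION_FUNCTIONS at the top of this file.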
if tokenizer.pad_token_id is None:
if model.config.pad_token_id is not None:
tokenizer.pad_token_id = model.config.pad_token_id
elif model.config.eos_token_id is not None and isinstance(
model.config.eos_token_id, int
):
tokenizer.pad_token_id = model.config.eos_token_id
elif tokenizer.eos_token_id is not None:
tokenizer.pad_token_id = tokenizer.eos_token_id
else:
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
self.num_layers = model.config.num_hidden_layers
self.num_heads = model.config.num_attention_heads // self.process_group.size()
self.num_kv_heads = model.config.num_key_value_heads
self.num_kv_heads = (
self.num_kv_heads // self.process_group.size()
if self.num_kv_heads > 1
else self.num_kv_heads
)
self.head_size = model.config.hidden_size // model.config.num_attention_heads
self.cuda_graphs = {}
self.kv_cache = []
self.kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype
if ATTENTION == "flashinfer":
from text_generation_server.layers.attention.flashinfer import (
create_prefill_state,
create_decode_state,
create_prefill_with_paged_kv_state,
)
self.prefill_state = create_prefill_state(device=device)
self.prefill_with_paged_kv_state = create_prefill_with_paged_kv_state(
device=device
)
self.decode_state = create_decode_state(
device=device,
num_heads=self.num_heads,
num_kv_heads=self.num_kv_heads,
)
self.num_groups = self.num_heads // self.num_kv_heads
# These never change and are reused in every forward pass
self.kv_head_mapping = torch.arange(
0, self.num_kv_heads, dtype=torch.int32, device=device
).repeat_interleave(self.num_groups)
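# e.g. 8 query heads over 2 KV heads -> num_groups = 4 and
# kv_head_mapping = [0, 0, 0, 0, 1, 1, 1, 1]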
# This means no scale
self.kv_scales = KVScales(
torch.tensor(1.0, device=device),
torch.tensor(1.0, device=device),
)
torch.distributed.barrier(group=self.process_group)
# Skip FlashCausalLM init.
super(FlashCausalLM, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
requires_padding=False,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)
# Monkey patch of `self.model.forward` to match `FlashCausalLM`. It avoids duplicating a lot of code
# We first copy the original model.forward because we still need it in the monkey patch
self.model.original_forward = self.model.forward
self.model.forward = self._model_forward
@classmethod
def fallback(
cls,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
return cls(
model_id=model_id,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
def _model_forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[KVCache],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
max_s: int,
lm_head_indices: Optional[torch.Tensor],
prefill_cache_indices=None, # not used, but passed to match original signature
adapter_data=None, # not supported, but passed to match original signature
):
hidden_states = self.model.model.forward(
input_ids=input_ids.unsqueeze(0), # expand dim to fit Transformers
position_ids=position_ids.unsqueeze(0), # expand dim to fit Transformers
past_key_values=None, # we use self.kv_cache instead of transformers cache object
use_cache=False, # we use self.kv_cache instead of transformers cache object
return_dict=True,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,
block_tables=block_tables,
slots=slots,
seqlen=seqlen,
max_s=max_s,
kv_head_mapping=self.kv_head_mapping,
kv_scales=self.kv_scales,
)[0].squeeze(dim=0)
# And compute logits from the lm_head, slicing correctly the indices
# NOTE: some logits post-processing (e.g. Gemma2's final logit softcapping) may be missing here
# because we call the inner model directly; to be updated once full Transformers support lands
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits = self.model.lm_head.forward(hidden_states)
return logits, None
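The extension point this file hinges on is the pluggable attention registry recent `transformers` releases expose: any callable assigned into `ALL_ATTENTION_FUNCTIONS` can be requested by name via `attn_implementation`, with extra tensors reaching it through the `**kwargs` passthrough. A minimal sketch of the pattern in isolation (the backend key, model id, and naive SDPA body are illustrative, not what this PR ships):

```python
import torch
import transformers.modeling_utils
from transformers import AutoModelForCausalLM

def naive_attention(module, query, key, value, attention_mask, **kwargs):
    # query/key/value arrive as [batch, num_heads, seq_len, head_dim].
    # Expand KV heads to match query heads so this also runs on GQA models.
    n_rep = query.shape[1] // key.shape[1]
    key = key.repeat_interleave(n_rep, dim=1)
    value = value.repeat_interleave(n_rep, dim=1)
    out = torch.nn.functional.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask, scale=kwargs.get("scaling")
    )
    # Transformers expects [batch, seq_len, num_heads, head_dim] plus optional weights.
    return out.transpose(1, 2).contiguous(), None

transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS["naive"] = naive_attention
model = AutoModelForCausalLM.from_pretrained(
    "HuggingFaceTB/SmolLM2-135M",  # illustrative small model using the new attention interface
    attn_implementation="naive",
)
```

`TransformersFlashCausalLM` registers `tgi_flash_attention_forward` the same way, except the body routes into TGI's flash/paged kernels and picks up `kv_cache`, `slots`, `seqlen`, etc. from the kwargs passthrough.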

View File

@@ -5,13 +5,12 @@ import torch
from typing import List, Optional, DefaultDict
from loguru import logger
from typing import Dict, Union
from typing import Dict
from text_generation_server.pb.generate_pb2 import GrammarType
from outlines.fsm.guide import RegexGuide
from transformers import (
LogitsWarper,
LogitsProcessor,
PreTrainedTokenizerBase,
TemperatureLogitsWarper,
@@ -219,7 +218,7 @@ class HeterogeneousTemperatureLogitsWarper:
return None
class HeterogeneousTopPLogitsWarper(LogitsWarper):
class HeterogeneousTopPLogitsWarper(LogitsProcessor):
"""
[`LogitsProcessor`] that performs top-p filtering, i.e. restricting to the smallest set of top tokens whose cumulative probability reaches `top_p`.
This version allows for a separate value for each sample and runs inplace when possible.
@@ -278,7 +277,7 @@ class HeterogeneousTopPLogitsWarper(LogitsWarper):
return None
class HeterogeneousTopKLogitsWarper(LogitsWarper):
class HeterogeneousTopKLogitsWarper(LogitsProcessor):
r"""
[`LogitsProcessor`] that performs top-k, i.e. restricting to the k highest-probability elements.
This version allows for a separate value for each sample and runs inplace when possible.
@@ -359,7 +358,7 @@ class HeterogeneousTopKLogitsWarper(LogitsWarper):
return None
class HeterogeneousTypicalLogitsWarper(LogitsWarper):
class HeterogeneousTypicalLogitsWarper(LogitsProcessor):
r"""
[`LogitsProcessor`] that performs typical decoding. See [Typical Decoding for Natural Language
Generation](https://arxiv.org/abs/2202.00666) for more information.
@@ -453,13 +452,13 @@ class HeterogeneousProcessorWrapper(LogitsProcessor):
r"""
A wrapper for logit warpers or processors without heterogeneous parameter support.
Args:
processors (`Dict[int, Union[LogitsProcessor, LogitsWarper]]`):
processors (`Dict[int, LogitsProcessor]`):
A mapping of sample indices to logit warpers or processors, to be run sequentially.
"""
def __init__(
self,
processors: Dict[int, Union[LogitsProcessor, LogitsWarper]],
processors: Dict[int, LogitsProcessor],
):
self.processors = processors
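Background for this file's change: `LogitsWarper` was deprecated upstream and later removed, and `LogitsProcessor` has the identical `__call__(input_ids, scores)` contract, so warper-style classes can simply inherit from it. A minimal sketch of a warper-style processor (the class is illustrative, not part of this PR):

```python
import torch
from transformers import LogitsProcessor

class TemperatureProcessor(LogitsProcessor):
    """Warper-style sampling transform implemented on the LogitsProcessor contract."""

    def __init__(self, temperature: float):
        self.temperature = temperature

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
    ) -> torch.FloatTensor:
        # Dividing logits by the temperature flattens (>1) or sharpens (<1) the distribution.
        return scores / self.temperature

processor = TemperatureProcessor(2.0)
scores = torch.tensor([[1.0, 2.0, 3.0]])
print(processor(torch.tensor([[0]]), scores))  # tensor([[0.5000, 1.0000, 1.5000]])
```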