From ade0f44aca75f5a5e9dd29141fb57013070e15fc Mon Sep 17 00:00:00 2001
From: Cyril Vallez
Date: Tue, 10 Dec 2024 16:46:55 +0100
Subject: [PATCH] add transformers_flash

---
 .../text_generation_server/models/__init__.py |  29 +-
 .../text_generation_server/models/globals.py  |   4 +
 .../models/transformers_flash_causal_lm.py    | 309 ++++++++++++++++++
 3 files changed, 341 insertions(+), 1 deletion(-)
 create mode 100644 server/text_generation_server/models/transformers_flash_causal_lm.py

diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index fcc79608..2f3ccc2d 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -20,6 +20,7 @@ from pathlib import Path
 from text_generation_server.utils.speculate import get_speculate, set_speculate
 from text_generation_server.models.model import Model
 from text_generation_server.models.causal_lm import CausalLM, CausalLMBatchKeysLast
+from text_generation_server.models.transformers_flash_causal_lm import TransformersFlashCausalLM
 from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM
 from text_generation_server.models.custom_modeling.mpt_modeling import (
     MPTForCausalLM,
@@ -28,7 +29,7 @@ from text_generation_server.models.bloom import BloomCausalLMBatch
 from text_generation_server.models.custom_modeling.bloom_modeling import (
     BloomForCausalLM,
 )
-from text_generation_server.models.globals import ATTENTION
+from text_generation_server.models.globals import ATTENTION, USE_CUSTOM_MODELING
 from text_generation_server.models.seq2seq_lm import Seq2SeqLM
 from text_generation_server.models.galactica import GalacticaCausalLMBatch
 from text_generation_server.models.custom_modeling.neox_modeling import (
@@ -366,12 +367,38 @@ def get_model(
     max_input_tokens: int,
 ) -> Model:
     global FLASH_ATTENTION
+    global USE_CUSTOM_MODELING
 
     config_dict, _ = PretrainedConfig.get_config_dict(
         model_id, revision=revision, trust_remote_code=trust_remote_code
     )
     model_type = config_dict.get("model_type", None)
 
+    transformers_causal_lm_class = CausalLM
+    if (
+        not USE_CUSTOM_MODELING
+        and model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
+    ):
+        logger.info(
+            "TGI's flash-enabled models could not be loaded or are disabled; using the Transformers fallback."
+        )
+        transformers_model_class = getattr(
+            transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type]
+        )
+
+        if (
+            transformers_model_class._supports_flash_attn_2
+            and transformers_model_class._supports_cache_class
+        ):
+            logger.info(
+                f"Transformers' {model_type} implementation supports custom cache and flash/paged attention. Using TransformersFlashCausalLM with ragged tensors (single dimension for batch and sequence length)."
+            )
+            transformers_causal_lm_class = TransformersFlashCausalLM
+        else:
+            logger.info(
+                f"Transformers' {model_type} implementation does not support custom cache and flash/paged attention. Using TransformersCausalLM with classic tensors with padding (two dimensions for batch size and sequence length)."
+            )
+
     quantization_config = config_dict.get("quantization_config", None)
     if quantization_config is None:
         quantization_config = config_dict.get("compression_config", None)
diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py
index 8d988ad5..7d6639f2 100644
--- a/server/text_generation_server/models/globals.py
+++ b/server/text_generation_server/models/globals.py
@@ -67,3 +67,7 @@ def set_adapter_to_index(adapter_to_index: Dict[str, int]):
 def get_adapter_to_index():
     global ADAPTER_TO_INDEX
     return ADAPTER_TO_INDEX
+
+
+USE_CUSTOM_MODELING = os.getenv("USE_CUSTOM_MODELING", "true")
+USE_CUSTOM_MODELING = USE_CUSTOM_MODELING == "true" or USE_CUSTOM_MODELING == "1"
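The USE_CUSTOM_MODELING flag added in globals.py is a plain environment toggle: only the strings "true" and "1" keep TGI's custom flash models enabled, and any other value routes get_model() to the Transformers fallback selected above. A minimal sketch of that parsing, separate from the patch and for illustration only (the variable names are local to the snippet):

import os

# Pretend the launcher exported USE_CUSTOM_MODELING=false for this process.
os.environ["USE_CUSTOM_MODELING"] = "false"

raw = os.getenv("USE_CUSTOM_MODELING", "true")     # defaults to "true"
use_custom_modeling = raw == "true" or raw == "1"  # same parsing as globals.py

assert use_custom_modeling is False  # -> the transformers_causal_lm_class path is taken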
diff --git a/server/text_generation_server/models/transformers_flash_causal_lm.py b/server/text_generation_server/models/transformers_flash_causal_lm.py
new file mode 100644
index 00000000..ff76b2cc
--- /dev/null
+++ b/server/text_generation_server/models/transformers_flash_causal_lm.py
@@ -0,0 +1,309 @@
+import math
+import sys
+from typing import Optional, Tuple, Dict, Any
+
+import torch
+from opentelemetry import trace
+from loguru import logger
+from transformers import AutoTokenizer, AutoModelForCausalLM
+
+from text_generation_server.models.flash_causal_lm import (
+    FlashCausalLMBatch,
+    FlashCausalLM,
+)
+from text_generation_server.utils.import_utils import (
+    empty_cache,
+    synchronize,
+    get_free_memory,
+)
+from text_generation_server.adapters import AdapterBatchData
+from text_generation_server.layers.attention import paged_attention, attention, Seqlen
+from text_generation_server.layers.attention.kv_cache import KVScales
+from text_generation_server.models.globals import ATTENTION
+from text_generation_server.models.metadata_kernels import block_tables_to_ragged
+
+
+tracer = trace.get_tracer(__name__)
+
+
+def patch_everywhere(
+    attribute_name: str, patch: Any, module_name_prefix: Optional[str] = None
+):
+    """
+    Finds all occurrences of `attribute_name` in the loaded modules and patches them with `patch`.
+
+    Args:
+        attribute_name (`str`):
+            The name of the attribute to patch.
+        patch (`Any`):
+            The patch for the attribute.
+        module_name_prefix (`Optional[str]`, defaults to `None`):
+            If set, only module names starting with this prefix will be considered for patching.
+    """
+    # sys.modules may be updated while being iterated over, hence the list copy.
+    for name in list(sys.modules):
+        module = sys.modules[name]
+        if module_name_prefix is not None and not name.startswith(module_name_prefix):
+            continue
+        if hasattr(module, attribute_name):
+            setattr(module, attribute_name, patch)
+
+
+def _flash_attention_forward_patched(
+    query_states: torch.Tensor,
+    key_states: torch.Tensor,
+    value_states: torch.Tensor,
+    attention_mask: torch.Tensor,
+    query_length: int,
+    is_causal: bool,
+    softmax_scale: Optional[float] = None,
+    sliding_window: int = -1,
+    softcap: Optional[float] = None,
+    **kwargs,
+):
+
+    kv_cache = kwargs["kv_cache"][kwargs["layer_idx"]]
+    # A scale of 1.0 means no scaling is applied to the KV cache
+    kv_scales = KVScales(torch.tensor(1.0, device=key_states.device), torch.tensor(1.0, device=key_states.device))
+
+    # Correctly reshape the states
+    _, _, num_heads, head_dim = query_states.size()
+    _, _, num_kv_heads, _ = key_states.size()
+    query_states = query_states.view(-1, num_heads, head_dim)
+    key_states = key_states.view(-1, num_kv_heads, head_dim)
+    value_states = value_states.view(-1, num_kv_heads, head_dim)
+
+    # Take care of updating the cache in-place
+    kv_cache.store(
+        key=key_states,
+        value=value_states,
+        slots=kwargs["slots"],
+        kv_scales=kv_scales
+    )
+
+    softmax_scale = 1 / math.sqrt(head_dim) if softmax_scale is None else softmax_scale
+
+    if kwargs["cu_seqlen_prefill"] is not None:
+        attn_output = attention(
+            query=query_states,
+            key=key_states,
+            value=value_states,
+            kv_cache=kv_cache,
+            kv_scales=kv_scales,
+            seqlen=kwargs["seqlen"],
+            block_tables=kwargs["block_tables"],
+            softmax_scale=softmax_scale,
+            window_size_left=sliding_window,
+            softcap=softcap,
+        )
+    else:
+        attn_output = paged_attention(
+            query_states,
+            kv_cache,
+            kwargs["kv_head_mapping"],
+            softmax_scale,
+            kwargs["block_tables"],
+            kwargs["seqlen"],
+            kwargs["max_s"],
+            kv_scales=kv_scales,
+            softcap=softcap,
+        )
+
+    attn_output = attn_output.view(attn_output.shape[0], -1)
+
+    return attn_output
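patch_everywhere() is what lets the fallback keep Transformers' modeling code while swapping in TGI's attention: every already-imported module that exposes the attribute gets it overwritten, which is how warmup() further below installs _flash_attention_forward_patched under the name "_flash_attention_forward". A toy sketch of the same mechanism, assuming the patch_everywhere definition above is in scope and using the standard-library math module only as a harmless stand-in target:

import math

# Illustration only: overwrite math.pi in every loaded module whose name
# starts with "math"; TGI's real call targets "_flash_attention_forward".
patch_everywhere("pi", 3.14, module_name_prefix="math")
assert math.pi == 3.14  # the attribute now resolves to the patched value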
+
+
+class TransformersFlashCausalLM(FlashCausalLM):
+    def __init__(
+        self,
+        model_id: str,
+        revision: Optional[str] = None,
+        quantize: Optional[str] = None,
+        speculator: Optional[str] = None,
+        dtype: Optional[torch.dtype] = None,
+        trust_remote_code: bool = False,
+    ):
+        if speculator:
+            raise RuntimeError("Speculator decoding is not enabled for AutoModel")
+
+        device_count = 0
+        if torch.cuda.is_available():
+            device = torch.device("cuda")
+            device_count = torch.cuda.device_count()
+            dtype = torch.float16 if dtype is None else dtype
+        elif hasattr(torch, "xpu") and torch.xpu.is_available():
+            device = torch.device("xpu")
+            device_count = torch.xpu.device_count()
+            dtype = torch.float16 if dtype is None else dtype
+        else:
+            if quantize:
+                raise ValueError("quantization is not available on CPU")
+
+            device = torch.device("cpu")
+            dtype = torch.float32 if dtype is None else dtype
+
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_id,
+            revision=revision,
+            padding_side="left",
+            truncation_side="left",
+            trust_remote_code=trust_remote_code,
+        )
+        model = AutoModelForCausalLM.from_pretrained(
+            model_id,
+            revision=revision,
+            torch_dtype=dtype,
+            device_map=("auto" if device_count > 1 else None),
+            load_in_8bit=quantize == "bitsandbytes",
+            trust_remote_code=trust_remote_code,
+        )
+        if device_count == 1 and quantize != "bitsandbytes":
+            model = model.to(device)
+
+        if tokenizer.pad_token_id is None:
+            if model.config.pad_token_id is not None:
+                tokenizer.pad_token_id = model.config.pad_token_id
+            elif model.config.eos_token_id is not None and isinstance(
+                model.config.eos_token_id, int
+            ):
+                tokenizer.pad_token_id = model.config.eos_token_id
+            elif tokenizer.eos_token_id is not None:
+                tokenizer.pad_token_id = tokenizer.eos_token_id
+            else:
+                tokenizer.add_special_tokens({"pad_token": "[PAD]"})
+
+
+        self.num_layers = len(model.model.layers)
+        self.num_kv_heads = model.config.num_key_value_heads
+        self.head_size = model.config.hidden_size // model.config.num_attention_heads
+
+        # Skip FlashCausalLM init.
+        super(FlashCausalLM, self).__init__(
+            model_id=model_id,
+            model=model,
+            tokenizer=tokenizer,
+            requires_padding=False,
+            dtype=dtype,
+            device=device,
+        )
+
+    def warmup(self, batch: FlashCausalLMBatch):
+        patch_everywhere("_flash_attention_forward", _flash_attention_forward_patched)
+        super().warmup(batch)
+
+    def forward(
+        self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        # NOTE: adapter_data: not supported
+
+        input_ids = batch.input_ids
+        position_ids = batch.position_ids
+        cu_seqlen_prefill = batch.cu_seqlen_prefill
+        kv_cache = self.kv_cache
+        block_tables = batch.block_tables_tensor
+        slots = batch.slots[batch.slot_indices]
+        input_lengths = batch.input_lengths_tensor
+        cache_lengths_tensor = batch.cache_lengths_tensor
+        max_s = batch.max_current_length
+        lm_head_indices = batch.prefill_head_indices
+
+        if cu_seqlen_prefill is None and self.max_past() is not None:
+            # In decode, not prefill, we're actually overwriting the KV-cache
+            # in a circular buffer mode.
+            # This makes sure the max_s for the decode pass is correct.
+            max_s = min(self.max_past(), max_s)
+
+        bs = input_ids.shape[0]
+        sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
+        if sorted_padded_bs:
+            # Get associated cuda graph
+            cuda_graph = self.cuda_graphs[sorted_padded_bs[0]]
+        else:
+            cuda_graph = None
+
+        if cu_seqlen_prefill is not None or cuda_graph is None:
+            if ATTENTION == "flashinfer":
+                block_tables = block_tables_to_ragged(
+                    block_tables=block_tables,
+                    input_lengths=batch.input_lengths,
+                    cache_lengths=batch.cache_lengths,
+                    input_lengths_tensor=batch.input_lengths_tensor,
+                    cache_lengths_tensor=batch.cache_lengths_tensor,
+                    max_current_length=batch.max_current_length,
+                )
+            with self._forward_context(
+                block_tables=block_tables,
+                cu_seqlen_prefill=cu_seqlen_prefill,
+                input_lengths_tensor=input_lengths,
+                cache_lengths_tensor=cache_lengths_tensor,
+            ):
+                seqlen = Seqlen(
+                    input_lengths=input_lengths,
+                    cache_lengths=cache_lengths_tensor,
+                    cu_seqlen_q=cu_seqlen_prefill,
+                    max_q=batch.max_input_length,
+                    max_k=batch.max_current_length,
+                )
+                logits = self.model.forward(
+                    input_ids=input_ids,
+                    position_ids=position_ids,
+                    past_key_values=None,
+                    use_cache=False,  # we use self.kv_cache instead of transformers cache object
+                    cu_seqlen_prefill=cu_seqlen_prefill,
+                    kv_cache=kv_cache,
+                    block_tables=block_tables,
+                    slots=slots,
+                    seqlen=seqlen,
+                    max_s=max_s,
+                    prefill_cache_indices=batch.prefill_cache_indices,
+                    lm_head_indices=lm_head_indices,
+                )
+                if batch.prefill_cache_indices is not None:
+                    batch.prefill_cache_indices = None
+                return logits, None
+
+        # Copy inputs to the static inputs of the cuda graph
+        # Static inputs are potentially padded
+        cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
+        cuda_graph["position_ids"][: position_ids.shape[-1]] = position_ids
+        if ATTENTION == "flashinfer":
+            block_tables = block_tables_to_ragged(
+                block_tables=block_tables,
+                input_lengths=batch.input_lengths,
+                cache_lengths=batch.cache_lengths,
+                input_lengths_tensor=batch.input_lengths_tensor,
+                cache_lengths_tensor=batch.cache_lengths_tensor,
+                max_current_length=batch.max_current_length,
+            )
+            # assert block_tables.shape[0] >= slots.shape[0]
+            cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
+        else:
+            cuda_graph["block_tables"][
+                : block_tables.shape[0], : block_tables.shape[1]
+            ] = block_tables
+
+        # XXX: This is working only because block 0 is reserved for the healthcheck
+        # so it doesn't matter if we override it with bogus values.
+        cuda_graph["slots"].fill_(0)
+        cuda_graph["slots"][: slots.shape[0]] = slots
+        cuda_graph["input_lengths"].zero_()
+        cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
+        cuda_graph["cache_lengths"].zero_()
+        cuda_graph["cache_lengths"][
+            : cache_lengths_tensor.shape[0]
+        ] = cache_lengths_tensor
+
+        with self._forward_context(
+            block_tables=cuda_graph["block_tables"],
+            cu_seqlen_prefill=None,
+            input_lengths_tensor=cuda_graph["input_lengths"],
+            cache_lengths_tensor=cuda_graph["cache_lengths"],
+            state=cuda_graph["state"],
+        ):
+            # Replay the graph
+            cuda_graph["graph"].replay()
+
+        # Slice output to the correct shape
+        logits = cuda_graph["logits"][:bs]
+        return logits, None
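One shape convention worth spelling out: Transformers hands _flash_attention_forward query/key/value states whose last two dimensions are (num_heads, head_dim), and the patched function above flattens every leading dimension into a single ragged "total tokens" axis, which is the layout TGI's attention/paged_attention kernels and kv_cache.store() expect. A small self-contained sketch of that reshape, with arbitrary illustrative shapes:

import torch

batch, seq_len, num_heads, head_dim = 2, 5, 8, 64
query_states = torch.randn(batch, seq_len, num_heads, head_dim)

# Same reshape as in _flash_attention_forward_patched: collapse batch and
# sequence into one token dimension while keeping heads and head_dim.
ragged = query_states.view(-1, num_heads, head_dim)
assert ragged.shape == (batch * seq_len, num_heads, head_dim)  # (10, 8, 64)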