diff --git a/server/text_generation_server/models/transformers_flash_causal_lm.py b/server/text_generation_server/models/transformers_flash_causal_lm.py
index 98fbf9a2..cfc9c861 100644
--- a/server/text_generation_server/models/transformers_flash_causal_lm.py
+++ b/server/text_generation_server/models/transformers_flash_causal_lm.py
@@ -1,32 +1,17 @@
 import math
-import sys
-from typing import List, Optional, Tuple, Dict, Any
+from typing import List, Optional
 
 import torch
 from opentelemetry import trace
-from loguru import logger
 from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
 import transformers.modeling_utils
 
-from text_generation_server.models.flash_causal_lm import (
-    FlashCausalLMBatch,
-    FlashCausalLM,
-)
-from text_generation_server.utils import (
-    initialize_torch_distributed,
-    weight_files,
-    Weights,
-)
-from text_generation_server.utils.import_utils import (
-    empty_cache,
-    synchronize,
-    get_free_memory,
-)
-from text_generation_server.adapters import AdapterBatchData
+from text_generation_server.models.flash_causal_lm import FlashCausalLM
+from text_generation_server.utils import initialize_torch_distributed
+
 from text_generation_server.layers.attention import paged_attention, attention, Seqlen
 from text_generation_server.layers.attention.kv_cache import KVScales, KVCache
 from text_generation_server.models.globals import ATTENTION
-from text_generation_server.models.metadata_kernels import block_tables_to_ragged
 
 tracer = trace.get_tracer(__name__)
 
@@ -48,7 +33,7 @@ def tgi_flash_attention_forward(
     softmax_scale: Optional[float] = None,
     sliding_window: Optional[int] = None,
     softcap: Optional[float] = None,
-    **_kwargs,  # This is needed to "absorb" other args passed by Transformers modeling
+    **kwargs,  # This is needed to "absorb" other args passed by Transformers modeling
 ):
     kv_cache = kv_cache[module.layer_idx]
 
@@ -222,6 +207,11 @@ class TransformersFlashCausalLM(FlashCausalLM):
             world_size=world_size,
         )
 
+        # Monkey patch of `self.model.forward` to match `FlashCausalLM`. It avoids duplicating a lot of code
+        # We first copy the original model.forward because we still need it in the monkey patch
+        self.model.original_forward = self.model.forward
+        self.model.forward = self._model_forward
+
     @classmethod
     def fallback(
         cls,
@@ -252,12 +242,15 @@ class TransformersFlashCausalLM(FlashCausalLM):
         seqlen: Seqlen,
         max_s: int,
         lm_head_indices: Optional[torch.Tensor],
+        prefill_cache_indices=None,  # not used, but passed to match original signature
+        adapter_data=None,  # not supported, but passed to match original signature
     ):
         # Transformers does not support None as a default
         if lm_head_indices is None:
             lm_head_indices = 0
 
-        logits = self.model.forward(
+        # Equivalent to `self.model.forward`, see the monkey patch in __init__
+        logits = self.model.original_forward(
             input_ids=input_ids.unsqueeze(0),  # expand dim to easily fit transformers
             position_ids=position_ids.unsqueeze(0),  # expand dim to easily fit transformers
             past_key_values=None,  # we use self.kv_cache instead of transformers cache object
@@ -272,292 +265,5 @@ class TransformersFlashCausalLM(FlashCausalLM):
             max_s=max_s,
             kv_head_mapping=self.kv_head_mapping,
         ).logits.squeeze(dim=0)
 
-        return logits
-
-    def forward(
-        self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
-        # NOTE: adapter_data: not supported
-
-        input_ids = batch.input_ids
-        position_ids = batch.position_ids
-        cu_seqlen_prefill = batch.cu_seqlen_prefill
-        kv_cache = self.kv_cache
-        block_tables = batch.block_tables_tensor
-        slots = batch.slots[batch.slot_indices]
-        input_lengths = batch.input_lengths_tensor
-        cache_lengths_tensor = batch.cache_lengths_tensor
-        max_s = batch.max_current_length
-        lm_head_indices = batch.prefill_head_indices
-
-        if cu_seqlen_prefill is None and self.max_past() is not None:
-            # In decode, not prefill, we're actually overwriting the KV-cache
-            # in a circular buffer mode.
-            # This makes sure the max_s for the decode pass is correct.
-            max_s = min(self.max_past(), max_s)
-
-        bs = input_ids.shape[0]
-        sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
-        if sorted_padded_bs:
-            # Get associated cuda graph
-            cuda_graph = self.cuda_graphs[sorted_padded_bs[0]]
-        else:
-            cuda_graph = None
-
-        if cu_seqlen_prefill is not None or cuda_graph is None:
-            if ATTENTION == "flashinfer":
-                block_tables = block_tables_to_ragged(
-                    block_tables=block_tables,
-                    input_lengths=batch.input_lengths,
-                    cache_lengths=batch.cache_lengths,
-                    input_lengths_tensor=batch.input_lengths_tensor,
-                    cache_lengths_tensor=batch.cache_lengths_tensor,
-                    max_current_length=batch.max_current_length,
-                )
-            with self._forward_context(
-                block_tables=block_tables,
-                cu_seqlen_prefill=cu_seqlen_prefill,
-                input_lengths_tensor=input_lengths,
-                cache_lengths_tensor=cache_lengths_tensor,
-            ):
-                seqlen = Seqlen(
-                    input_lengths=input_lengths,
-                    cache_lengths=cache_lengths_tensor,
-                    cu_seqlen_q=cu_seqlen_prefill,
-                    max_q=batch.max_input_length,
-                    max_k=batch.max_current_length,
-                )
-                logits = self._model_forward(
-                    input_ids=input_ids,
-                    position_ids=position_ids,
-                    cu_seqlen_prefill=cu_seqlen_prefill,
-                    kv_cache=kv_cache,
-                    block_tables=block_tables,
-                    slots=slots,
-                    seqlen=seqlen,
-                    max_s=max_s,
-                    lm_head_indices=lm_head_indices,
-                )
-                if batch.prefill_cache_indices is not None:
-                    batch.prefill_cache_indices = None
-                return logits, None
-
-        # Copy inputs to the static inputs of the cuda graph
-        # Static inputs are potentially padded
-        cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
-        cuda_graph["position_ids"][: position_ids.shape[-1]] = position_ids
-        if ATTENTION == "flashinfer":
-            block_tables = block_tables_to_ragged(
-                block_tables=block_tables,
-                input_lengths=batch.input_lengths,
-                cache_lengths=batch.cache_lengths,
-                input_lengths_tensor=batch.input_lengths_tensor,
-                cache_lengths_tensor=batch.cache_lengths_tensor,
-                max_current_length=batch.max_current_length,
-            )
-            # assert block_tables.shape[0] >= slots.shape[0]
-            cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
-        else:
-            cuda_graph["block_tables"][
-                : block_tables.shape[0], : block_tables.shape[1]
-            ] = block_tables
-
-        # XXX: This is working only because block 0 is reserved for the healthcheck
-        # so it doesn't matter if we override it with bogus values.
-        cuda_graph["slots"].fill_(0)
-        cuda_graph["slots"][: slots.shape[0]] = slots
-        cuda_graph["input_lengths"].zero_()
-        cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
-        cuda_graph["cache_lengths"].zero_()
-        cuda_graph["cache_lengths"][
-            : cache_lengths_tensor.shape[0]
-        ] = cache_lengths_tensor
-
-        with self._forward_context(
-            block_tables=cuda_graph["block_tables"],
-            cu_seqlen_prefill=None,
-            input_lengths_tensor=cuda_graph["input_lengths"],
-            cache_lengths_tensor=cuda_graph["cache_lengths"],
-            state=cuda_graph["state"],
-        ):
-            # Replay the graph
-            cuda_graph["graph"].replay()
-
-        # Slice output to the correct shape
-        logits = cuda_graph["logits"][:bs]
-        return logits, None
-
-
-    def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
-        max_bs = max(self.cuda_graphs.keys()) if self.cuda_graphs else None
-        input_lengths = [max_s] * bs
-        cache_lengths = [0] * bs
-        if max_bs is None:
-            input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
-            position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
-            slots = torch.arange(bs, dtype=torch.int64, device=self.device)
-            input_lengths_tensor = (
-                torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
-            )
-            cache_lengths_tensor = torch.zeros(
-                bs, dtype=torch.int32, device=self.device
-            )
-            block_tables = torch.arange(
-                max_bt, dtype=torch.int32, device=self.device
-            ).repeat(bs)
-            block_tables = block_tables.reshape((bs, max_bt))
-            if ATTENTION == "flashinfer":
-                block_tables = block_tables_to_ragged(
-                    block_tables=block_tables,
-                    input_lengths=input_lengths,
-                    cache_lengths=cache_lengths,
-                    input_lengths_tensor=input_lengths_tensor,
-                    cache_lengths_tensor=cache_lengths_tensor,
-                    max_current_length=max_s,
-                )
-        else:
-            if bs > max_bs:
-                raise RuntimeError(
-                    "Cuda graphs should be generated in decreasing order size to reduce VRAM usage"
-                )
-            input_ids = self.cuda_graphs[max_bs]["input_ids"][:bs]
-            position_ids = self.cuda_graphs[max_bs]["position_ids"][:bs]
-            if ATTENTION == "flashinfer":
-                block_tables = self.cuda_graphs[max_bs]["block_tables"][: bs * max_bt]
-            else:
-                block_tables = self.cuda_graphs[max_bs]["block_tables"][:bs]
-            slots = self.cuda_graphs[max_bs]["slots"][:bs]
-            input_lengths_tensor = self.cuda_graphs[max_bs]["input_lengths"][:bs]
-            cache_lengths_tensor = self.cuda_graphs[max_bs]["cache_lengths"][:bs]
-
-        if ATTENTION == "flashinfer":
-            from text_generation_server.layers.attention.flashinfer import (
-                create_decode_state_cuda_graphs,
-            )
-
-            block_tables_ptr = torch.zeros(
-                bs + 1, dtype=torch.int32, device=self.device
-            )
-            last_page_len = torch.ones(bs, dtype=torch.int32, device=self.device)
-            state = create_decode_state_cuda_graphs(
-                device=input_ids.device,
-                block_tables=block_tables,
-                block_tables_ptr=block_tables_ptr,
-                last_page_len=last_page_len,
-                num_heads=self.num_heads,
-                num_kv_heads=self.num_kv_heads,
-            )
-        else:
-            state = None
-
-        if (
-            hasattr(self.model, "config")
-            and hasattr(self.model.config, "model_type")
-            and self.model.config.model_type == "qwen2_vl"
-        ):
-            if position_ids.dim() == 1:
-                position_ids = self.model.get_position_ids(input_ids)
-
-        graph = torch.cuda.CUDAGraph()
-        self.cuda_graphs[bs] = {
-            "input_ids": input_ids,
-            "position_ids": position_ids,
-            "kv_cache": self.kv_cache,
-            "block_tables": block_tables,
-            "slots": slots,
-            "input_lengths": input_lengths_tensor,
-            "cache_lengths": cache_lengths_tensor,
-            "state": state,
-            "graph": graph,
-        }
-
-        torch.cuda.synchronize()
-        # Run once outside to warmup
-        with self._forward_context(
-            block_tables=block_tables,
-            cu_seqlen_prefill=None,
-            input_lengths_tensor=input_lengths_tensor,
-            state=state,
-            cache_lengths_tensor=cache_lengths_tensor,
-        ):
-            seqlen = Seqlen(
-                input_lengths=input_lengths_tensor,
-                cache_lengths=cache_lengths_tensor,
-                cu_seqlen_q=None,
-                max_q=1,
-                max_k=max_s,
-            )
-            self._model_forward(
-                input_ids=input_ids,
-                position_ids=position_ids,
-                cu_seqlen_prefill=None,
-                kv_cache=self.kv_cache,
-                block_tables=block_tables,
-                slots=slots,
-                seqlen=seqlen,
-                max_s=max_s,
-                lm_head_indices=None,
-            )
-            del seqlen
-
-        torch.cuda.synchronize()
-
-        with torch.cuda.graph(graph, pool=MEM_POOL):
-            seqlen = Seqlen(
-                input_lengths=input_lengths_tensor,
-                cache_lengths=cache_lengths_tensor,
-                cu_seqlen_q=None,
-                max_q=1,
-                max_k=max_s,
-            )
-            logits = self._model_forward(
-                input_ids=input_ids,
-                position_ids=position_ids,
-                cu_seqlen_prefill=None,
-                kv_cache=self.kv_cache,
-                block_tables=block_tables,
-                slots=slots,
-                seqlen=seqlen,
-                max_s=max_s,
-                lm_head_indices=None,
-            )
-            self.cuda_graphs[bs]["logits"] = logits
-            self.cuda_graphs[bs]["speculative_logits"] = None
-        torch.cuda.synchronize()
-
-
-    def tunableop_warmup(self, seqlen: int):
-        input_ids = torch.zeros(seqlen, dtype=torch.int64, device=self.device)
-        position_ids = torch.zeros(seqlen, dtype=torch.int32, device=self.device)
-        slots = torch.arange(seqlen, dtype=torch.int64, device=self.device)
-
-        # Dummy value, some models (starcoder2) don't accept `None`.
-        input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
-        cache_lengths_tensor = torch.zeros(
-            seqlen, dtype=torch.int32, device=self.device
-        )
-        cu_seqlen_prefill = torch.tensor(
-            [0, seqlen], device=self.device, dtype=torch.int32
-        )
-        max_s = seqlen
-        seqlen = Seqlen(
-            input_lengths=input_lengths,
-            cache_lengths=cache_lengths_tensor,
-            cu_seqlen_q=cu_seqlen_prefill,
-            max_q=1,
-            max_k=seqlen,
-        )
-
-        # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
-        self._model_forward(
-            input_ids=input_ids,
-            position_ids=position_ids,
-            cu_seqlen_prefill=cu_seqlen_prefill,
-            kv_cache=self.kv_cache,
-            block_tables=None,
-            seqlen=seqlen,
-            slots=slots,
-            max_s=max_s,
-            lm_head_indices=None,
-        )
\ No newline at end of file
+        return logits, None
\ No newline at end of file
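
Reviewer note: the core of this patch is the monkey patch added in `__init__`. `TransformersFlashCausalLM` no longer carries its own copies of `forward`, `cuda_graph_warmup`, and `tunableop_warmup`; instead it rebinds `self.model.forward` to `_model_forward`, whose signature matches what the inherited `FlashCausalLM` code paths call, while the original Transformers forward stays reachable as `original_forward`. Below is a minimal, self-contained sketch of that rebinding pattern only; `InnerModel` and `Runner` are hypothetical stand-ins, not TGI or Transformers classes.

```python
# Minimal sketch of the forward-rebinding pattern used in this patch.
# `InnerModel` and `Runner` are illustrative placeholders, not real TGI classes.


class InnerModel:
    """Stands in for the Transformers model whose `forward` gets wrapped."""

    def forward(self, input_ids, **kwargs):
        # Fake "logits": one value per token id.
        return [x + 1 for x in input_ids]


class Runner:
    """Stands in for TransformersFlashCausalLM / the parent FlashCausalLM."""

    def __init__(self, model):
        self.model = model
        # Keep a handle on the original bound method, then rebind `forward`
        # so that parent code calling `self.model.forward(...)` lands here.
        self.model.original_forward = self.model.forward
        self.model.forward = self._model_forward

    def _model_forward(self, input_ids, lm_head_indices=None, **unused_parent_kwargs):
        # Adapt the parent-style call, then delegate to the stashed forward.
        if lm_head_indices is None:
            lm_head_indices = 0  # mirrors the None-handling in the patch
        logits = self.model.original_forward(input_ids=input_ids)
        return logits, None


runner = Runner(InnerModel())
# A parent-style call goes through the rebound method transparently.
logits, _ = runner.model.forward(input_ids=[1, 2, 3], lm_head_indices=None)
print(logits)  # [2, 3, 4]
```

Because `forward` is reassigned on the model instance, the rebound method shadows the class-level `forward` for that instance only, so the inherited `FlashCausalLM` forward and warmup paths can keep calling `self.model.forward(...)` unchanged.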