diff --git a/server/text_generation_server/models/transformers_flash_causal_lm.py b/server/text_generation_server/models/transformers_flash_causal_lm.py index d71e75b4..7bcad8aa 100644 --- a/server/text_generation_server/models/transformers_flash_causal_lm.py +++ b/server/text_generation_server/models/transformers_flash_causal_lm.py @@ -44,7 +44,7 @@ def tgi_flash_attention_forward( **kwargs, ): - kv_cache = kwargs["kv_cache"][kwargs["layer_idx"]] + kv_cache = kwargs["kv_cache"][module.layer_idx] # This means no scale kv_scales=KVScales(torch.tensor(1., device=key_states.device), torch.tensor(1., device=key_states.device)) @@ -97,7 +97,6 @@ def tgi_flash_attention_forward( softcap=softcap, ) - # attn_output = attn_output.view(attn_output.shape[0], -1) attn_output = attn_output.view(-1, num_heads * head_dim) return attn_output, None @@ -244,6 +243,42 @@ class TransformersFlashCausalLM(FlashCausalLM): trust_remote_code=trust_remote_code, ) + + def _model_forward( + self, + input_ids, + position_ids, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + seqlen, + max_s, + prefill_cache_indices, + lm_head_indices, + ): + hidden_states = self.model.model.forward( + input_ids=input_ids[None, ...], # expand dim to easily fit transformers + position_ids=position_ids[None, ...], # expand dim to easily fit transformers + past_key_values=None, # we use self.kv_cache instead of transformers cache object + use_cache=False, # we use self.kv_cache instead of transformers cache object + return_dict=True, + cu_seqlen_prefill=cu_seqlen_prefill, + kv_cache=kv_cache, + block_tables=block_tables, + slots=slots, + seqlen=seqlen, + max_s=max_s, + prefill_cache_indices=prefill_cache_indices, + kv_head_mapping=self.kv_head_mapping, + )[0].squeeze(dim=0) + # And compute logits from the lm_head, slicing correctly the indices + if lm_head_indices is not None: + hidden_states = hidden_states[lm_head_indices] + logits = self.model.lm_head.forward(hidden_states) + return logits + + def forward( self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: @@ -297,13 +332,9 @@ class TransformersFlashCausalLM(FlashCausalLM): max_q=batch.max_input_length, max_k=batch.max_current_length, ) - # Use only the Model, not ModelForCausalLM - hidden_states = self.model.model.forward( - input_ids=input_ids[None, ...], # expand dim to easily fit transformers - position_ids=position_ids[None, ...], - past_key_values=None, - use_cache=False, # we use self.kv_cache instead of transformers cache object - return_dict=True, + logits = self._model_forward( + input_ids=input_ids, + position_ids=position_ids, cu_seqlen_prefill=cu_seqlen_prefill, kv_cache=kv_cache, block_tables=block_tables, @@ -311,10 +342,8 @@ class TransformersFlashCausalLM(FlashCausalLM): seqlen=seqlen, max_s=max_s, prefill_cache_indices=batch.prefill_cache_indices, - kv_head_mapping=self.kv_head_mapping, - )[0].squeeze(dim=0) - # And compute logits from the lm_head, slicing correctly the indices - logits = self.model.lm_head.forward(hidden_states[lm_head_indices]) + lm_head_indices=lm_head_indices, + ) if batch.prefill_cache_indices is not None: batch.prefill_cache_indices = None return logits, None @@ -363,3 +392,180 @@ class TransformersFlashCausalLM(FlashCausalLM): # Slice output to the correct shape logits = cuda_graph["logits"][:bs] return logits, None + + + def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int): + max_bs = max(self.cuda_graphs.keys()) if self.cuda_graphs else None + input_lengths = [max_s] * bs + cache_lengths = [0] * bs + if max_bs is None: + input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device) + position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device) + slots = torch.arange(bs, dtype=torch.int64, device=self.device) + input_lengths_tensor = ( + torch.ones(bs, dtype=torch.int32, device=self.device) * max_s + ) + cache_lengths_tensor = torch.zeros( + bs, dtype=torch.int32, device=self.device + ) + block_tables = torch.arange( + max_bt, dtype=torch.int32, device=self.device + ).repeat(bs) + block_tables = block_tables.reshape((bs, max_bt)) + if ATTENTION == "flashinfer": + block_tables = block_tables_to_ragged( + block_tables=block_tables, + input_lengths=input_lengths, + cache_lengths=cache_lengths, + input_lengths_tensor=input_lengths_tensor, + cache_lengths_tensor=cache_lengths_tensor, + max_current_length=max_s, + ) + else: + if bs > max_bs: + raise RuntimeError( + "Cuda graphs should be generated in decreasing order size to reduce VRAM usage" + ) + input_ids = self.cuda_graphs[max_bs]["input_ids"][:bs] + position_ids = self.cuda_graphs[max_bs]["position_ids"][:bs] + if ATTENTION == "flashinfer": + block_tables = self.cuda_graphs[max_bs]["block_tables"][: bs * max_bt] + else: + block_tables = self.cuda_graphs[max_bs]["block_tables"][:bs] + slots = self.cuda_graphs[max_bs]["slots"][:bs] + input_lengths_tensor = self.cuda_graphs[max_bs]["input_lengths"][:bs] + cache_lengths_tensor = self.cuda_graphs[max_bs]["cache_lengths"][:bs] + + if ATTENTION == "flashinfer": + from text_generation_server.layers.attention.flashinfer import ( + create_decode_state_cuda_graphs, + ) + + block_tables_ptr = torch.zeros( + bs + 1, dtype=torch.int32, device=self.device + ) + last_page_len = torch.ones(bs, dtype=torch.int32, device=self.device) + state = create_decode_state_cuda_graphs( + device=input_ids.device, + block_tables=block_tables, + block_tables_ptr=block_tables_ptr, + last_page_len=last_page_len, + num_heads=self.num_heads, + num_kv_heads=self.num_kv_heads, + ) + else: + state = None + + if ( + hasattr(self.model, "config") + and hasattr(self.model.config, "model_type") + and self.model.config.model_type == "qwen2_vl" + ): + if position_ids.dim() == 1: + position_ids = self.model.get_position_ids(input_ids) + + graph = torch.cuda.CUDAGraph() + self.cuda_graphs[bs] = { + "input_ids": input_ids, + "position_ids": position_ids, + "kv_cache": self.kv_cache, + "block_tables": block_tables, + "slots": slots, + "input_lengths": input_lengths_tensor, + "cache_lengths": cache_lengths_tensor, + "state": state, + "graph": graph, + } + + torch.cuda.synchronize() + # Run once outside to warmup + with self._forward_context( + block_tables=block_tables, + cu_seqlen_prefill=None, + input_lengths_tensor=input_lengths_tensor, + state=state, + cache_lengths_tensor=cache_lengths_tensor, + ): + seqlen = Seqlen( + input_lengths=input_lengths_tensor, + cache_lengths=cache_lengths_tensor, + cu_seqlen_q=None, + max_q=1, + max_k=max_s, + ) + self._model_forward( + input_ids=input_ids, + position_ids=position_ids, + cu_seqlen_prefill=None, + kv_cache=self.kv_cache, + block_tables=block_tables, + slots=slots, + seqlen=seqlen, + max_s=max_s, + prefill_cache_indices=None, + lm_head_indices=None, + ) + del seqlen + + torch.cuda.synchronize() + + with torch.cuda.graph(graph, pool=MEM_POOL): + seqlen = Seqlen( + input_lengths=input_lengths_tensor, + cache_lengths=cache_lengths_tensor, + cu_seqlen_q=None, + max_q=1, + max_k=max_s, + ) + logits = self._model_forward( + input_ids=input_ids, + position_ids=position_ids, + cu_seqlen_prefill=None, + kv_cache=self.kv_cache, + block_tables=block_tables, + slots=slots, + seqlen=seqlen, + max_s=max_s, + prefill_cache_indices=None, + lm_head_indices=None, + ) + self.cuda_graphs[bs]["logits"] = logits + self.cuda_graphs[bs]["speculative_logits"] = None + torch.cuda.synchronize() + + + def tunableop_warmup(self, seqlen: int): + input_ids = torch.zeros(seqlen, dtype=torch.int64, device=self.device) + position_ids = torch.zeros(seqlen, dtype=torch.int32, device=self.device) + slots = torch.arange(seqlen, dtype=torch.int64, device=self.device) + + # Dummy value, some models (starcoder2) don't accept `None`. + input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device) + cache_lengths_tensor = torch.zeros( + seqlen, dtype=torch.int32, device=self.device + ) + cu_seqlen_prefill = torch.tensor( + [0, seqlen], device=self.device, dtype=torch.int32 + ) + max_s = seqlen + seqlen = Seqlen( + input_lengths=input_lengths, + cache_lengths=cache_lengths_tensor, + cu_seqlen_q=cu_seqlen_prefill, + max_q=1, + max_k=seqlen, + ) + + # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation. + self._model_forward( + input_ids=input_ids, + position_ids=position_ids, + cu_seqlen_prefill=cu_seqlen_prefill, + kv_cache=self.kv_cache, + block_tables=None, + seqlen=seqlen, + slots=slots, + max_s=max_s, + lm_head_indices=None, + prefill_cache_indices=None, + ) \ No newline at end of file