From 839477670aed6498c74b785585ad321ef5f7b3c7 Mon Sep 17 00:00:00 2001 From: "Wang, Yi" Date: Wed, 11 Jun 2025 21:00:21 +0800 Subject: [PATCH] [gaudi] Perf optimization (#3256) Signed-off-by: Wang, Yi A --- .../layers/attention/__init__.py | 2 + .../layers/attention/hpu.py | 24 ++- .../custom_modeling/flash_cohere_modeling.py | 5 + .../custom_modeling/flash_dbrx_modeling.py | 5 + .../flash_deepseek_v2_modeling.py | 5 + .../flash_deepseek_v3_modeling.py | 5 + .../custom_modeling/flash_gemma2_modeling.py | 5 + .../custom_modeling/flash_gemma_modeling.py | 5 + .../custom_modeling/flash_gpt2_modeling.py | 5 + .../custom_modeling/flash_gptj_modeling.py | 5 + .../custom_modeling/flash_llama4_modeling.py | 5 + .../custom_modeling/flash_llama_modeling.py | 6 + .../custom_modeling/flash_mistral_modeling.py | 5 + .../custom_modeling/flash_mixtral_modeling.py | 6 +- .../custom_modeling/flash_neox_modeling.py | 5 + .../custom_modeling/flash_phi_modeling.py | 5 + .../custom_modeling/flash_qwen2_modeling.py | 6 +- .../custom_modeling/flash_qwen3_modeling.py | 7 +- .../custom_modeling/flash_rw_modeling.py | 5 + .../flash_santacoder_modeling.py | 5 + .../flash_starcoder2_modeling.py | 6 +- .../models/flash_causal_lm.py | 160 ++++++++++++------ .../models/flash_vlm_causal_lm.py | 4 +- .../models/mllama_causal_lm.py | 4 +- 24 files changed, 229 insertions(+), 66 deletions(-) diff --git a/backends/gaudi/server/text_generation_server/layers/attention/__init__.py b/backends/gaudi/server/text_generation_server/layers/attention/__init__.py index 370e05bc..aa639832 100644 --- a/backends/gaudi/server/text_generation_server/layers/attention/__init__.py +++ b/backends/gaudi/server/text_generation_server/layers/attention/__init__.py @@ -11,6 +11,7 @@ from .hpu import ( attention, paged_attention, paged_attention_mla, + set_block_mapping, ) @@ -22,6 +23,7 @@ __all__ = [ "get_kv_scales", "paged_attention", "paged_attention_mla", + "set_block_mapping", "SUPPORTS_WINDOWING", "KVCache", "KVCompressCache", diff --git a/backends/gaudi/server/text_generation_server/layers/attention/hpu.py b/backends/gaudi/server/text_generation_server/layers/attention/hpu.py index 8cca7a29..f12005d2 100644 --- a/backends/gaudi/server/text_generation_server/layers/attention/hpu.py +++ b/backends/gaudi/server/text_generation_server/layers/attention/hpu.py @@ -8,6 +8,7 @@ from habana_frameworks.torch.hpex.kernels import FusedSDPA from vllm_hpu_extension.utils import ModuleFusedSDPA import os from text_generation_server.models.globals import BLOCK_SIZE +import math SUPPORTS_WINDOWING = False @@ -106,6 +107,21 @@ def attention( return attn_output +def set_block_mapping(hpu_attention_meta: HPUPagedAttentionMetadata, batch_size): + block_mapping = torch.nn.functional.one_hot( + hpu_attention_meta.block_groups, num_classes=batch_size + ) + dtype = hpu_attention_meta.block_usage.dtype + device = hpu_attention_meta.block_usage.device + mask = torch.arange(0, BLOCK_SIZE, device=device, dtype=torch.int32).unsqueeze(0) + mask = mask >= hpu_attention_meta.block_usage.unsqueeze(-1) + attn_bias = torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf) + hpu_attention_meta = hpu_attention_meta._replace( + attn_bias=attn_bias, block_mapping=block_mapping.to(dtype) + ) + return hpu_attention_meta + + def paged_attention( query: torch.Tensor, kv_cache: KVCache, @@ -176,4 +192,10 @@ def paged_attention_mla( return output.view(batch_size, head_num, -1) -__all__ = ["SUPPORTS_WINDOWING", "attention", "paged_attention", "paged_attention_mla"] +__all__ = [ 
+ "SUPPORTS_WINDOWING", + "attention", + "paged_attention", + "paged_attention_mla", + "set_block_mapping", +] diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py index 801ae09e..7a32a85c 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py @@ -28,6 +28,7 @@ from typing import Optional, List, Tuple from text_generation_server.layers.attention import ( paged_attention, attention, + set_block_mapping, Seqlen, HPUPagedAttentionMetadata, ) @@ -415,6 +416,10 @@ class FlashCohereModel(torch.nn.Module): seqlen: torch.Tensor, hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: + if hpu_attention_meta is not None: + hpu_attention_meta = set_block_mapping( + hpu_attention_meta, input_ids.shape[0] + ) hidden_states = self.embed_tokens(input_ids) # Get rotary cos and sin for this forward diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py index 76972d38..42af7798 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py @@ -26,6 +26,7 @@ from text_generation_server.layers.attention.kv_cache import get_kv_scales from text_generation_server.layers.attention import ( paged_attention, attention, + set_block_mapping, Seqlen, HPUPagedAttentionMetadata, ) @@ -678,6 +679,10 @@ class DbrxModel(torch.nn.Module): seqlen: Seqlen, hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: + if hpu_attention_meta is not None: + hpu_attention_meta = set_block_mapping( + hpu_attention_meta, input_ids.shape[0] + ) hidden_states = self.embed_tokens(input_ids) # Get rotary cos and sin for this forward diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py index 6ac7fc1a..8e9002a2 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py @@ -33,6 +33,7 @@ from text_generation_server.layers.attention import ( Seqlen, attention, paged_attention, + set_block_mapping, HPUPagedAttentionMetadata, ) from text_generation_server.layers.attention.kv_cache import KVCache, get_kv_scales @@ -569,6 +570,10 @@ class DeepseekV2Model(torch.nn.Module): seqlen: Seqlen, hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: + if hpu_attention_meta is not None: + hpu_attention_meta = set_block_mapping( + hpu_attention_meta, input_ids.shape[0] + ) hidden_states = self.embed_tokens(input_ids) # Get rotary cos and sin for this forward diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v3_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v3_modeling.py index e0481691..8e058093 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v3_modeling.py +++ 
b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v3_modeling.py @@ -34,6 +34,7 @@ from text_generation_server.layers.attention import ( Seqlen, attention, paged_attention_mla, + set_block_mapping, HPUPagedAttentionMetadata, ) from text_generation_server.layers.attention.kv_cache import KVCache, get_kv_scales @@ -645,6 +646,10 @@ class DeepseekV3Model(torch.nn.Module): seqlen: Seqlen, hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: + if hpu_attention_meta is not None: + hpu_attention_meta = set_block_mapping( + hpu_attention_meta, input_ids.shape[0] + ) hidden_states = self.embed_tokens(input_ids) # Get rotary cos and sin for this forward diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py index a5860823..a1a20999 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py @@ -28,6 +28,7 @@ from typing import Optional, List, Tuple from text_generation_server.layers.attention import ( paged_attention, attention, + set_block_mapping, Seqlen, HPUPagedAttentionMetadata, ) @@ -466,6 +467,10 @@ class FlashGemma2Model(torch.nn.Module): adapter_data: Optional[torch.Tensor], hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: + if hpu_attention_meta is not None: + hpu_attention_meta = set_block_mapping( + hpu_attention_meta, inputs_embeds.shape[0] + ) hidden_states = inputs_embeds # Get rotary cos and sin for this forward diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py index 3d678df1..7a2ec22e 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py @@ -28,6 +28,7 @@ from typing import Optional, List, Tuple from text_generation_server.layers.attention import ( paged_attention, attention, + set_block_mapping, Seqlen, HPUPagedAttentionMetadata, ) @@ -388,6 +389,10 @@ class FlashGemmaModel(torch.nn.Module): adapter_data: Optional[torch.Tensor], hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: + if hpu_attention_meta is not None: + hpu_attention_meta = set_block_mapping( + hpu_attention_meta, inputs_embeds.shape[0] + ) hidden_states = inputs_embeds # Get rotary cos and sin for this forward diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py index ed413662..a6b53656 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py @@ -27,6 +27,7 @@ from typing import Optional, List, Tuple from text_generation_server.layers.attention import ( paged_attention, attention, + set_block_mapping, Seqlen, HPUPagedAttentionMetadata, ) @@ -383,6 +384,10 @@ class FlashGPT2Model(torch.nn.Module): seqlen: Seqlen, hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: + if hpu_attention_meta is not None: + hpu_attention_meta = 
set_block_mapping( + hpu_attention_meta, inputs_embeds.shape[0] + ) hidden_states = inputs_embeds residual = None diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py index cde03a00..679380a1 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py @@ -28,6 +28,7 @@ from text_generation_server.layers.attention.kv_cache import get_kv_scales from text_generation_server.layers.attention import ( paged_attention, attention, + set_block_mapping, Seqlen, HPUPagedAttentionMetadata, ) @@ -324,6 +325,10 @@ class FlashGPTJModel(torch.nn.Module): seqlen: Seqlen, hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: + if hpu_attention_meta is not None: + hpu_attention_meta = set_block_mapping( + hpu_attention_meta, input_ids.shape[0] + ) hidden_states = self.wte(input_ids) # Get rotary cos and sin for this forward diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py index 0e3af85a..c6b68f33 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py @@ -43,6 +43,7 @@ from text_generation_server.layers.layernorm import FastRMSNorm from text_generation_server.layers.attention import ( KVCache, paged_attention, + set_block_mapping, Seqlen, HPUPagedAttentionMetadata, ) @@ -548,6 +549,10 @@ class Llama4TextModel(nn.Module): hpu_attention_meta: Optional[HPUPagedAttentionMetadata], attention_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: + if hpu_attention_meta is not None: + hpu_attention_meta = set_block_mapping( + hpu_attention_meta, inputs_embeds.shape[0] + ) hidden_states = inputs_embeds bs = seqlen.input_lengths.shape[0] diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index dfb16621..70fcc824 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -35,6 +35,7 @@ from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoE from text_generation_server.layers.attention import ( paged_attention, attention, + set_block_mapping, Seqlen, HPUPagedAttentionMetadata, ) @@ -549,6 +550,11 @@ class FlashLlamaModel(torch.nn.Module): hpu_attention_meta: Optional[HPUPagedAttentionMetadata], cross_attention_states=None, ) -> torch.Tensor: + if hpu_attention_meta is not None: + hpu_attention_meta = set_block_mapping( + hpu_attention_meta, inputs_embeds.shape[0] + ) + hidden_states = inputs_embeds # Get rotary cos and sin for this forward diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py index 75d9d360..a4ad8f59 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py +++ 
b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py @@ -30,6 +30,7 @@ from text_generation_server.layers.attention.kv_cache import get_kv_scales from text_generation_server.layers.attention import ( paged_attention, attention, + set_block_mapping, Seqlen, HPUPagedAttentionMetadata, ) @@ -396,6 +397,10 @@ class MistralModel(torch.nn.Module): hpu_attention_meta: Optional[HPUPagedAttentionMetadata], adapter_data: Optional[torch.Tensor] = None, ): + if hpu_attention_meta is not None: + hpu_attention_meta = set_block_mapping( + hpu_attention_meta, inputs_embeds.shape[0] + ) hidden_states = inputs_embeds # Get rotary cos and sin for this forward # Avoid to index in each layer diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py index f47986d8..4993b444 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py @@ -37,6 +37,7 @@ from text_generation_server.layers.attention import ( Seqlen, attention, paged_attention, + set_block_mapping, HPUPagedAttentionMetadata, ) from text_generation_server.layers.attention.kv_cache import get_kv_scales @@ -446,6 +447,10 @@ class MixtralModel(torch.nn.Module): seqlen: Seqlen, hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: + if hpu_attention_meta is not None: + hpu_attention_meta = set_block_mapping( + hpu_attention_meta, input_ids.shape[0] + ) hidden_states = self.embed_tokens(input_ids) # Get rotary cos and sin for this forward @@ -505,7 +510,6 @@ class FlashMixtralForCausalLM(torch.nn.Module): lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, ) -> torch.Tensor: - hidden_states = self.model( input_ids, position_ids, diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index 29620826..6e1050b6 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -29,6 +29,7 @@ from typing import Optional, List, Tuple from text_generation_server.layers.attention import ( paged_attention, attention, + set_block_mapping, Seqlen, HPUPagedAttentionMetadata, ) @@ -354,6 +355,10 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel): seqlen: Seqlen, hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: + if hpu_attention_meta is not None: + hpu_attention_meta = set_block_mapping( + hpu_attention_meta, input_ids.shape[0] + ) hidden_states = self.embed_in(input_ids) # Get rotary cos and sin for this forward diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py index 12830991..78aaf0d5 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py @@ -9,6 +9,7 @@ from typing import Optional, List, Tuple from text_generation_server.layers.attention import ( paged_attention, attention, + 
set_block_mapping, Seqlen, HPUPagedAttentionMetadata, ) @@ -347,6 +348,10 @@ class FlashPhiModel(torch.nn.Module): seqlen: Seqlen, hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: + if hpu_attention_meta is not None: + hpu_attention_meta = set_block_mapping( + hpu_attention_meta, input_ids.shape[0] + ) hidden_states = self.embed_tokens(input_ids) # Get rotary cos and sin for this forward diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py index 7c7ac03e..ac31e53b 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py @@ -8,6 +8,7 @@ from typing import Optional, List, Tuple from text_generation_server.layers.attention import ( paged_attention, attention, + set_block_mapping, Seqlen, HPUPagedAttentionMetadata, ) @@ -288,6 +289,10 @@ class Qwen2Model(torch.nn.Module): seqlen: Seqlen, hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: + if hpu_attention_meta is not None: + hpu_attention_meta = set_block_mapping( + hpu_attention_meta, inputs_embeds.shape[0] + ) hidden_states = inputs_embeds cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( @@ -359,7 +364,6 @@ class Qwen2ForCausalLM(torch.nn.Module): lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, ) -> torch.Tensor: - inputs_embeds = self.embed_tokens(input_ids) hidden_states = self.model( diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_modeling.py index 66a17877..8bd00c13 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_modeling.py @@ -18,6 +18,7 @@ import habana_frameworks.torch as htorch from text_generation_server.layers.attention import ( paged_attention, attention, + set_block_mapping, Seqlen, HPUPagedAttentionMetadata, ) @@ -266,7 +267,10 @@ class Qwen3Model(nn.Module): seqlen: Seqlen, hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: - + if hpu_attention_meta is not None: + hpu_attention_meta = set_block_mapping( + hpu_attention_meta, inputs_embeds.shape[0] + ) hidden_states = inputs_embeds # create position embeddings to be shared across the decoder layers @@ -334,7 +338,6 @@ class Qwen3ForCausalLM(nn.Module): lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, ) -> torch.Tensor: - inputs_embeds = self.embed_tokens(input_ids) # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) hidden_states = self.model( diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index 76a2cd01..06616f85 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -18,6 +18,7 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.layers.attention import ( attention, 
paged_attention, + set_block_mapping, Seqlen, HPUPagedAttentionMetadata, ) @@ -628,6 +629,10 @@ class FlashRWModel(FlashRWPreTrainedModel): seqlen: Seqlen, hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: + if hpu_attention_meta is not None: + hpu_attention_meta = set_block_mapping( + hpu_attention_meta, input_ids.shape[0] + ) hidden_states = self.word_embeddings(input_ids) # Get rotary cos and sin for this forward diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index c64b2ff7..b6a0d32a 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -8,6 +8,7 @@ from typing import Optional, List, Tuple from text_generation_server.layers.attention import ( paged_attention, attention, + set_block_mapping, Seqlen, HPUPagedAttentionMetadata, ) @@ -437,6 +438,10 @@ class FlashSantacoderModel(nn.Module): seqlen: Seqlen, hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: + if hpu_attention_meta is not None: + hpu_attention_meta = set_block_mapping( + hpu_attention_meta, input_ids.shape[0] + ) hidden_states = self.wte(input_ids) + self.wpe(position_ids) if self.process_group.size() > 1: diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py index 94c60eb6..1a749595 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py @@ -29,6 +29,7 @@ from typing import Optional, List, Tuple from text_generation_server.layers.attention import ( paged_attention, attention, + set_block_mapping, Seqlen, HPUPagedAttentionMetadata, ) @@ -511,6 +512,10 @@ class Starcoder2Model(torch.nn.Module): adapter_data, hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: + if hpu_attention_meta is not None: + hpu_attention_meta = set_block_mapping( + hpu_attention_meta, input_ids.shape[0] + ) hidden_states = self.embed_tokens(input_ids) # Get rotary cos and sin for this forward @@ -584,7 +589,6 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module): lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, ) -> torch.Tensor: - hidden_states = self.model( input_ids, position_ids, diff --git a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py index f8abe5ad..13a2a307 100644 --- a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py @@ -153,19 +153,14 @@ def prepare_for_decode( block_list_device = _async_h2d_tensor_copy(block_list) block_groups_device = _async_h2d_tensor_copy(block_groups) block_usage_device = _async_h2d_tensor_copy(block_usage) - block_mapping = torch.nn.functional.one_hot( - block_groups_device, num_classes=batch_size - ) - mask = torch.arange(0, BLOCK_SIZE, device=device, dtype=torch.int32).unsqueeze(0) - mask = mask >= block_usage_device.unsqueeze(-1) - attn_bias = torch.zeros_like(mask, 
dtype=dtype).masked_fill_(mask, -math.inf) + return trim_attn_metadata( HPUPagedAttentionMetadata( block_list=block_list_device, block_groups=block_groups_device, block_usage=block_usage_device, - block_mapping=block_mapping.to(dtype), - attn_bias=attn_bias, + block_mapping=None, + attn_bias=None, ) ) @@ -428,10 +423,8 @@ class FlashCausalLMBatch(Batch): for i, input_ids in enumerate(all_input_ids): all_input_ids_tensor[i, : len(input_ids)] = input_ids - # Create tensors on device - all_input_ids_tensor = torch.tensor( - all_input_ids_tensor, dtype=torch.int64, device=device - ) + # put on cpu temporarily, move to hpu in prepare_for_prefill + all_input_ids_tensor = torch.tensor(all_input_ids_tensor, dtype=torch.int64) top_n_tokens_tensor = torch.tensor(top_n_tokens, dtype=torch.int64) @@ -701,7 +694,9 @@ class FlashCausalLMBatch(Batch): @classmethod @tracer.start_as_current_span("concatenate") - def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch": + def concatenate( + cls, batches: List["FlashCausalLMBatch"], padded_total_bs: int = 0 + ) -> "FlashCausalLMBatch": # Batch attributes requests = [] requests_idx_mapping = {} @@ -750,7 +745,10 @@ class FlashCausalLMBatch(Batch): adapter_meta = None adapter_segment_builder = None else: - input_ids = batches[0].input_ids.new_empty(total_batch_size) + if padded_total_bs == batches[0].input_ids.shape[0]: + input_ids = batches[0].input_ids + else: + input_ids = batches[0].input_ids.new_empty(total_batch_size) if ( batches[0].position_ids is not None and batches[0].position_ids.dim() == 2 @@ -784,9 +782,7 @@ class FlashCausalLMBatch(Batch): block_tables_tensor = batches[0].block_tables_tensor.new_zeros( (total_batch_size, max_blocks) ) - all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros( - (total_batch_size, max_length) - ) + all_input_ids_tensor = batches[0].all_input_ids_tensor top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros( total_batch_size, ) @@ -829,9 +825,12 @@ class FlashCausalLMBatch(Batch): index = torch.tensor(list(range(start_index, end_index)), device="cpu") top_n_tokens_tensor.index_copy_(0, index, batch.top_n_tokens_tensor) - all_input_ids_tensor[ - start_index:end_index, : batch.all_input_ids_tensor.shape[1] - ] = batch.all_input_ids_tensor[:valid_bsize, :max_length] + if i > 0: + all_input_ids_tensor.index_copy_( + 0, + index.to(batch.all_input_ids_tensor.device), + batch.all_input_ids_tensor[:valid_bsize, :], + ) block_tables_tensor[ start_index:end_index, : batch.block_tables_tensor.shape[1] @@ -851,9 +850,10 @@ class FlashCausalLMBatch(Batch): ) if not prefilling: - input_ids.index_copy_( - 0, index.to(input_ids.device), batch.input_ids[:valid_bsize] - ) + if padded_total_bs != batches[0].input_ids.shape[0] or i > 0: + input_ids.index_copy_( + 0, index.to(input_ids.device), batch.input_ids[:valid_bsize] + ) position_ids.index_copy_(0, index, batch.position_ids[:valid_bsize]) slot_indices.index_copy_( 0, index, batch.slot_indices + cumulative_slots @@ -987,7 +987,6 @@ class FlashCausalLMBatch(Batch): else: padded_bs = self.input_ids.shape[0] slots = self.slots[self.slot_indices] - extra_pad = padded_bs - self.input_ids.shape[0] self.hpu_attn_meta = prepare_for_decode( dtype, @@ -998,17 +997,20 @@ class FlashCausalLMBatch(Batch): padded_bs, bucketing_ctx, ) - self.input_ids = F.pad(self.input_ids, (0, extra_pad), value=0) - self.position_ids = F.pad(self.position_ids, (0, extra_pad), value=1) + self.input_ids = F.pad( + self.input_ids, (0, padded_bs - 
self.input_ids.shape[0]), value=0 + ) + self.position_ids = F.pad( + self.position_ids, (0, padded_bs - self.position_ids.shape[0]), value=1 + ) self.input_lengths_tensor = F.pad( - self.input_lengths_tensor, (0, extra_pad), value=0 + self.input_lengths_tensor, + (0, padded_bs - self.input_lengths_tensor.shape[0]), + value=0, ) self.cache_lengths_tensor = F.pad( - self.cache_lengths_tensor, (0, extra_pad), value=0 - ) - self.all_input_ids_tensor = F.pad( - self.all_input_ids_tensor, - (0, 0, 0, extra_pad), + self.cache_lengths_tensor, + (0, padded_bs - self.cache_lengths_tensor.shape[0]), value=0, ) next_token_chooser_parameters = [] @@ -1028,7 +1030,9 @@ class FlashCausalLMBatch(Batch): fsm_grammar_states, ) - def prepare_for_prefill(self, max_padded_input_len, max_padded_bs): + def prepare_for_prefill( + self, max_padded_input_len, max_padded_bs, max_total_tokens + ): # Prepare values if we need to continue prefilling # Speculation must be ignored while we prefill even with chunking # it simplifies everything @@ -1044,7 +1048,7 @@ class FlashCausalLMBatch(Batch): # need extra pad to match warmup seq extra_pad = max_padded_input_len - self.max_input_length extra_pad_bs = max_padded_bs - len(self) - device = self.all_input_ids_tensor.device + device = "hpu" if isinstance(self.input_ids, list) and len(self) > 1: input_ids_padded_length = [] input_ids = [] @@ -1288,12 +1292,17 @@ class FlashCausalLMBatch(Batch): self.prefill_next_token_indices = ( self.prefill_next_token_indices + input_ids_padded_length_tensor ) - - self.all_input_ids_tensor = F.pad( - self.all_input_ids_tensor, - (0, 0, 0, extra_pad_bs), - value=0, + all_input_ids_tensor = torch.zeros( + (max_padded_bs, max(max_total_tokens, self.all_input_ids_tensor.shape[-1])), + dtype=torch.int64, + device="hpu", ) + for i in range(len(self)): + all_input_ids_tensor[i, : self.all_input_ids_tensor.shape[-1]] = ( + self.all_input_ids_tensor[i] + ) + self.all_input_ids_tensor = all_input_ids_tensor + next_token_chooser_parameters = [] next_token_chooser_parameters.extend([r.parameters for r in self.requests]) pad_next_token_chooser_parameters(next_token_chooser_parameters, max_padded_bs) @@ -1459,6 +1468,8 @@ class FlashCausalLM(Model): self.kv_cache = [] self.kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype self.bucketing_ctx = None + self.max_total_tokens = None + self.max_input_tokens = None htorch.core.hpu_set_env() if htorch.utils.internal.is_lazy(): htorch.hpu.wrap_in_hpu_graph(model, disable_tensor_cache=True) @@ -1564,6 +1575,14 @@ class FlashCausalLM(Model): logger.info, f"Free memory on device {self.device}: {format_bytes(free_memory)} used_for_graph: {format_bytes(mem_used_from_graph)} ratio {graph_reserved_mem} reserved_for_runtime: {format_bytes(self.mem_reserved)}", ) + if max_total_tokens is None: + max_total_tokens = sum(batch.input_lengths) + + if max_input_tokens is None: + max_input_tokens = max_total_tokens - 1 + + self.max_total_tokens = max_total_tokens + self.max_input_tokens = max_input_tokens try: self.init_kv_cache( batch.num_blocks, @@ -1597,11 +1616,6 @@ class FlashCausalLM(Model): ) log_master(logger.info, f"KV-cache blocks: {num_blocks}, size: {BLOCK_SIZE}") - if max_total_tokens is None: - max_total_tokens = sum(batch.input_lengths) - - if max_input_tokens is None: - max_input_tokens = max_total_tokens - 1 self.kv_cache = [] empty_cache() @@ -2017,7 +2031,9 @@ class FlashCausalLM(Model): accepted_ids, speculative_ids, ) = batch.next_token_chooser( - batch.all_input_ids_tensor[:, : 
batch.max_current_length], + batch.all_input_ids_tensor[ + : batch.next_token_logits.shape[0], : batch.max_current_length + ], batch.next_token_logits, speculate, batch.speculative_ids, @@ -2031,14 +2047,29 @@ class FlashCausalLM(Model): accepted_ids, ) if batch.valid_indices is not None: - next_token_logprobs = next_token_logprobs.cpu() - accepted_ids = accepted_ids.cpu() - batch.all_input_ids_tensor = batch.all_input_ids_tensor[ - batch.valid_indices - ] - next_input_ids = next_input_ids[batch.valid_indices] - next_token_logprobs = next_token_logprobs[batch.valid_indices] - accepted_ids = accepted_ids[batch.valid_indices] + # TODO speculative decoding handling missing + index = torch.arange( + 0, + len(batch.valid_indices), + device=batch.all_input_ids_tensor.device, + ) + batch.all_input_ids_tensor.index_copy_( + 0, index, batch.all_input_ids_tensor[batch.valid_indices] + ) + padded_total_bs = self.bucketing_ctx.get_padded_decode_batch_size( + len(batch.valid_indices) + ) + next_input_ids.index_copy_( + 0, index, next_input_ids[batch.valid_indices] + ) + next_input_ids = next_input_ids[:padded_total_bs] + + next_token_logprobs.index_copy_( + 0, index, next_token_logprobs[batch.valid_indices] + ) + accepted_ids.index_copy_( + 0, index, accepted_ids[batch.valid_indices] + ) if speculative_ids is not None: speculative_ids = speculative_ids[batch.valid_indices] batch.top_n_tokens_tensor = batch.top_n_tokens_tensor[ @@ -2106,10 +2137,13 @@ class FlashCausalLM(Model): batch.slot_indices += accepted_ids[: len(batch)] else: index = batch.cache_lengths_tensor + batch.input_lengths_tensor + index = F.pad( + index, (0, next_input_ids.shape[0] - index.shape[0]), value=0 + ) index = index.to(batch.all_input_ids_tensor.device) batch_idx = torch.arange( 0, - batch.all_input_ids_tensor.shape[0], + index.shape[0], dtype=torch.long, device=batch.all_input_ids_tensor.device, ) @@ -2197,7 +2231,18 @@ class FlashCausalLM(Model): htorch.core.mark_step() # Stage 2. 
Prepare new batch for speculative scheduling if len(batches) > 1: - batch = self.batch_type.concatenate(batches) + if self.bucketing_ctx is not None: + total_batch_size = 0 + for b in batches: + total_batch_size += len(b) + padded_total_bs = self.bucketing_ctx.get_padded_decode_batch_size( + total_batch_size + ) + batch = self.batch_type.concatenate( + batches, padded_total_bs=padded_total_bs + ) + else: + batch = self.batch_type.concatenate(batches) else: batch = batches[0] prefill = batch.prefilling @@ -2208,9 +2253,12 @@ class FlashCausalLM(Model): batch.max_input_length ), self.bucketing_ctx.get_padded_prompt_batch_size(len(batch)), + self.max_total_tokens, ) else: - batch.prepare_for_prefill(batch.max_input_length, len(batch)) + batch.prepare_for_prefill( + batch.max_input_length, len(batch), self.max_total_tokens + ) else: batch.prepare_for_decode( self.dtype, self.use_contiguous_pa, self.bucketing_ctx diff --git a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py index e604fd3c..9755ee6d 100644 --- a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py @@ -262,8 +262,8 @@ class FlashVlmCausalLMBatch(FlashCausalLMBatch): @classmethod @tracer.start_as_current_span("concatenate") - def concatenate(cls, batches): - batch = super(FlashVlmCausalLMBatch, cls).concatenate(batches) + def concatenate(cls, batches, padded_total_bs: int = 0): + batch = super(FlashVlmCausalLMBatch, cls).concatenate(batches, padded_total_bs) batch.pixel_values = None batch.pixel_attention_mask = None batch.image_sizes = None diff --git a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py index 771cc0a8..13939974 100644 --- a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py @@ -48,8 +48,8 @@ class FlashMllamaCausalLMBatch(FlashVlmCausalLMBatch): @classmethod @tracer.start_as_current_span("concatenate") - def concatenate(cls, batches): - batch = super().concatenate(batches) + def concatenate(cls, batches, padded_total_bs: int = 0): + batch = super().concatenate(batches, padded_total_bs) batch.pixel_values = None batch.pixel_attention_mask = None
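
---

For readers skimming the diff: the core change is that the one-hot block-to-sequence mapping and the padding attention bias are no longer built on the host in `prepare_for_decode`; each model's forward now calls the new `set_block_mapping` helper on device when `hpu_attention_meta` is present. Below is a minimal, self-contained sketch of that pattern in plain PyTorch. `PagedMeta`, the `forward` stub, and the `BLOCK_SIZE` value are simplified stand-ins for illustration only, not the backend's real classes.

```python
# Minimal sketch of the set_block_mapping pattern from the patch.
# PagedMeta is a simplified stand-in for HPUPagedAttentionMetadata;
# BLOCK_SIZE = 128 is an assumed value (the backend reads it from models.globals).
import math
from typing import NamedTuple, Optional

import torch

BLOCK_SIZE = 128  # assumption for this sketch


class PagedMeta(NamedTuple):
    block_groups: torch.Tensor        # (num_blocks,) int64: owning sequence index per block
    block_usage: torch.Tensor         # (num_blocks,) float (e.g. bfloat16): valid slots per block
    block_mapping: Optional[torch.Tensor] = None
    attn_bias: Optional[torch.Tensor] = None


def set_block_mapping(meta: PagedMeta, batch_size: int) -> PagedMeta:
    """Build the one-hot block->sequence mapping and the padding bias on device.

    Deferring this work to the model forward keeps host-side batch preparation
    cheap and runs the tensor construction on the accelerator instead.
    """
    block_mapping = torch.nn.functional.one_hot(
        meta.block_groups, num_classes=batch_size
    )
    dtype = meta.block_usage.dtype    # expected to be a floating dtype so -inf is representable
    device = meta.block_usage.device
    mask = torch.arange(0, BLOCK_SIZE, device=device, dtype=torch.int32).unsqueeze(0)
    mask = mask >= meta.block_usage.unsqueeze(-1)        # True for padded (unused) slots
    attn_bias = torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf)
    return meta._replace(attn_bias=attn_bias, block_mapping=block_mapping.to(dtype))


def forward(input_ids: torch.Tensor, hpu_attention_meta: Optional[PagedMeta]):
    """Mirrors the per-model change in the diff: guard on the metadata, then fill it in."""
    if hpu_attention_meta is not None:  # decode path with paged attention
        hpu_attention_meta = set_block_mapping(hpu_attention_meta, input_ids.shape[0])
    ...  # embed tokens and run the decoder layers with hpu_attention_meta
```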