diff --git a/backends/gaudi/server/text_generation_server/layers/attention/__init__.py b/backends/gaudi/server/text_generation_server/layers/attention/__init__.py
index 370e05bc..aa639832 100644
--- a/backends/gaudi/server/text_generation_server/layers/attention/__init__.py
+++ b/backends/gaudi/server/text_generation_server/layers/attention/__init__.py
@@ -11,6 +11,7 @@ from .hpu import (
     attention,
     paged_attention,
     paged_attention_mla,
+    set_block_mapping,
 )
 
 
@@ -22,6 +23,7 @@ __all__ = [
     "get_kv_scales",
     "paged_attention",
     "paged_attention_mla",
+    "set_block_mapping",
     "SUPPORTS_WINDOWING",
     "KVCache",
     "KVCompressCache",
diff --git a/backends/gaudi/server/text_generation_server/layers/attention/hpu.py b/backends/gaudi/server/text_generation_server/layers/attention/hpu.py
index 8cca7a29..f12005d2 100644
--- a/backends/gaudi/server/text_generation_server/layers/attention/hpu.py
+++ b/backends/gaudi/server/text_generation_server/layers/attention/hpu.py
@@ -8,6 +8,7 @@ from habana_frameworks.torch.hpex.kernels import FusedSDPA
 from vllm_hpu_extension.utils import ModuleFusedSDPA
 import os
 from text_generation_server.models.globals import BLOCK_SIZE
+import math
 
 SUPPORTS_WINDOWING = False
 
@@ -106,6 +107,21 @@ def attention(
     return attn_output
 
 
+def set_block_mapping(hpu_attention_meta: HPUPagedAttentionMetadata, batch_size):
+    block_mapping = torch.nn.functional.one_hot(
+        hpu_attention_meta.block_groups, num_classes=batch_size
+    )
+    dtype = hpu_attention_meta.block_usage.dtype
+    device = hpu_attention_meta.block_usage.device
+    mask = torch.arange(0, BLOCK_SIZE, device=device, dtype=torch.int32).unsqueeze(0)
+    mask = mask >= hpu_attention_meta.block_usage.unsqueeze(-1)
+    attn_bias = torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf)
+    hpu_attention_meta = hpu_attention_meta._replace(
+        attn_bias=attn_bias, block_mapping=block_mapping.to(dtype)
+    )
+    return hpu_attention_meta
+
+
 def paged_attention(
     query: torch.Tensor,
     kv_cache: KVCache,
@@ -176,4 +192,10 @@ def paged_attention_mla(
     return output.view(batch_size, head_num, -1)
 
 
-__all__ = ["SUPPORTS_WINDOWING", "attention", "paged_attention", "paged_attention_mla"]
+__all__ = [
+    "SUPPORTS_WINDOWING",
+    "attention",
+    "paged_attention",
+    "paged_attention_mla",
+    "set_block_mapping",
+]
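Note (reviewer comment, not part of the patch): set_block_mapping derives the two remaining fields of HPUPagedAttentionMetadata from the per-block bookkeeping it already carries: block_mapping, a one-hot [num_blocks, batch_size] matrix routing each KV-cache block to the sequence that owns it, and attn_bias, which masks the unused tail of each block with -inf. A minimal CPU sketch of the same arithmetic, assuming BLOCK_SIZE = 4 and a two-sequence batch (values made up):

    # Standalone illustration of the arithmetic inside set_block_mapping.
    # Plain CPU tensors and an assumed BLOCK_SIZE; the real value comes from
    # text_generation_server.models.globals.
    import math
    import torch

    BLOCK_SIZE = 4  # assumption for this sketch
    block_groups = torch.tensor([0, 0, 1])  # owning sequence index of each KV block
    block_usage = torch.tensor([4, 2, 3])   # number of valid slots in each block
    batch_size = 2

    # One-hot [num_blocks, batch_size]: block i contributes to sequence block_groups[i].
    block_mapping = torch.nn.functional.one_hot(block_groups, num_classes=batch_size)

    # Slots at or past block_usage are padding and receive a -inf additive bias.
    mask = torch.arange(0, BLOCK_SIZE, dtype=torch.int32).unsqueeze(0)
    mask = mask >= block_usage.unsqueeze(-1)
    attn_bias = torch.zeros_like(mask, dtype=torch.float32).masked_fill_(mask, -math.inf)

    print(block_mapping)  # [[1, 0], [1, 0], [0, 1]]
    print(attn_bias)      # rows: [0, 0, 0, 0], [0, 0, -inf, -inf], [0, 0, 0, -inf]

Because the metadata is a namedtuple, the helper returns a _replace()d copy rather than mutating in place, which is why every caller below re-assigns the result.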
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py
index 801ae09e..7a32a85c 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py
@@ -28,6 +28,7 @@ from typing import Optional, List, Tuple
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
+    set_block_mapping,
     Seqlen,
     HPUPagedAttentionMetadata,
 )
@@ -415,6 +416,10 @@ class FlashCohereModel(torch.nn.Module):
         seqlen: torch.Tensor,
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
     ) -> torch.Tensor:
+        if hpu_attention_meta is not None:
+            hpu_attention_meta = set_block_mapping(
+                hpu_attention_meta, input_ids.shape[0]
+            )
         hidden_states = self.embed_tokens(input_ids)
 
         # Get rotary cos and sin for this forward
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py
index 76972d38..42af7798 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py
@@ -26,6 +26,7 @@ from text_generation_server.layers.attention.kv_cache import get_kv_scales
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
+    set_block_mapping,
     Seqlen,
     HPUPagedAttentionMetadata,
 )
@@ -678,6 +679,10 @@ class DbrxModel(torch.nn.Module):
         seqlen: Seqlen,
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
     ) -> torch.Tensor:
+        if hpu_attention_meta is not None:
+            hpu_attention_meta = set_block_mapping(
+                hpu_attention_meta, input_ids.shape[0]
+            )
         hidden_states = self.embed_tokens(input_ids)
 
         # Get rotary cos and sin for this forward
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py
index 6ac7fc1a..8e9002a2 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py
@@ -33,6 +33,7 @@ from text_generation_server.layers.attention import (
     Seqlen,
     attention,
     paged_attention,
+    set_block_mapping,
     HPUPagedAttentionMetadata,
 )
 from text_generation_server.layers.attention.kv_cache import KVCache, get_kv_scales
@@ -569,6 +570,10 @@ class DeepseekV2Model(torch.nn.Module):
         seqlen: Seqlen,
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
     ) -> torch.Tensor:
+        if hpu_attention_meta is not None:
+            hpu_attention_meta = set_block_mapping(
+                hpu_attention_meta, input_ids.shape[0]
+            )
         hidden_states = self.embed_tokens(input_ids)
 
         # Get rotary cos and sin for this forward
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v3_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v3_modeling.py
index e0481691..8e058093 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v3_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v3_modeling.py
@@ -34,6 +34,7 @@ from text_generation_server.layers.attention import (
     Seqlen,
     attention,
     paged_attention_mla,
+    set_block_mapping,
     HPUPagedAttentionMetadata,
 )
 from text_generation_server.layers.attention.kv_cache import KVCache, get_kv_scales
@@ -645,6 +646,10 @@ class DeepseekV3Model(torch.nn.Module):
         seqlen: Seqlen,
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
     ) -> torch.Tensor:
+        if hpu_attention_meta is not None:
+            hpu_attention_meta = set_block_mapping(
+                hpu_attention_meta, input_ids.shape[0]
+            )
         hidden_states = self.embed_tokens(input_ids)
 
         # Get rotary cos and sin for this forward
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py
index a5860823..a1a20999 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py
@@ -28,6 +28,7 @@ from typing import Optional, List, Tuple
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
+    set_block_mapping,
     Seqlen,
     HPUPagedAttentionMetadata,
 )
@@ -466,6 +467,10 @@ class FlashGemma2Model(torch.nn.Module):
         adapter_data: Optional[torch.Tensor],
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
     ) -> torch.Tensor:
+        if hpu_attention_meta is not None:
+            hpu_attention_meta = set_block_mapping(
+                hpu_attention_meta, inputs_embeds.shape[0]
+            )
         hidden_states = inputs_embeds
 
         # Get rotary cos and sin for this forward
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py
index 3d678df1..7a2ec22e 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py
@@ -28,6 +28,7 @@ from typing import Optional, List, Tuple
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
+    set_block_mapping,
     Seqlen,
     HPUPagedAttentionMetadata,
 )
@@ -388,6 +389,10 @@ class FlashGemmaModel(torch.nn.Module):
         adapter_data: Optional[torch.Tensor],
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
     ) -> torch.Tensor:
+        if hpu_attention_meta is not None:
+            hpu_attention_meta = set_block_mapping(
+                hpu_attention_meta, inputs_embeds.shape[0]
+            )
         hidden_states = inputs_embeds
 
         # Get rotary cos and sin for this forward
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py
index ed413662..a6b53656 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py
@@ -27,6 +27,7 @@ from typing import Optional, List, Tuple
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
+    set_block_mapping,
     Seqlen,
     HPUPagedAttentionMetadata,
 )
@@ -383,6 +384,10 @@ class FlashGPT2Model(torch.nn.Module):
         seqlen: Seqlen,
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
     ) -> torch.Tensor:
+        if hpu_attention_meta is not None:
+            hpu_attention_meta = set_block_mapping(
+                hpu_attention_meta, inputs_embeds.shape[0]
+            )
         hidden_states = inputs_embeds
 
         residual = None
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py
index cde03a00..679380a1 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py
@@ -28,6 +28,7 @@ from text_generation_server.layers.attention.kv_cache import get_kv_scales
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
+    set_block_mapping,
     Seqlen,
     HPUPagedAttentionMetadata,
 )
@@ -324,6 +325,10 @@ class FlashGPTJModel(torch.nn.Module):
         seqlen: Seqlen,
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
     ) -> torch.Tensor:
+        if hpu_attention_meta is not None:
+            hpu_attention_meta = set_block_mapping(
+                hpu_attention_meta, input_ids.shape[0]
+            )
         hidden_states = self.wte(input_ids)
 
         # Get rotary cos and sin for this forward
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py
index 0e3af85a..c6b68f33 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py
@@ -43,6 +43,7 @@ from text_generation_server.layers.layernorm import FastRMSNorm
 from text_generation_server.layers.attention import (
     KVCache,
     paged_attention,
+    set_block_mapping,
     Seqlen,
     HPUPagedAttentionMetadata,
 )
@@ -548,6 +549,10 @@ class Llama4TextModel(nn.Module):
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
         attention_mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+        if hpu_attention_meta is not None:
+            hpu_attention_meta = set_block_mapping(
+                hpu_attention_meta, inputs_embeds.shape[0]
+            )
         hidden_states = inputs_embeds
 
         bs = seqlen.input_lengths.shape[0]
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index dfb16621..70fcc824 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -35,6 +35,7 @@ from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoE
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
+    set_block_mapping,
     Seqlen,
     HPUPagedAttentionMetadata,
 )
@@ -549,6 +550,11 @@ class FlashLlamaModel(torch.nn.Module):
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
         cross_attention_states=None,
     ) -> torch.Tensor:
+        if hpu_attention_meta is not None:
+            hpu_attention_meta = set_block_mapping(
+                hpu_attention_meta, inputs_embeds.shape[0]
+            )
+
         hidden_states = inputs_embeds
 
         # Get rotary cos and sin for this forward
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
index 75d9d360..a4ad8f59 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
@@ -30,6 +30,7 @@ from text_generation_server.layers.attention.kv_cache import get_kv_scales
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
+    set_block_mapping,
     Seqlen,
     HPUPagedAttentionMetadata,
 )
@@ -396,6 +397,10 @@ class MistralModel(torch.nn.Module):
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
         adapter_data: Optional[torch.Tensor] = None,
     ):
+        if hpu_attention_meta is not None:
+            hpu_attention_meta = set_block_mapping(
+                hpu_attention_meta, inputs_embeds.shape[0]
+            )
         hidden_states = inputs_embeds
         # Get rotary cos and sin for this forward
         # Avoid to index in each layer
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
index f47986d8..4993b444 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
@@ -37,6 +37,7 @@ from text_generation_server.layers.attention import (
     Seqlen,
     attention,
     paged_attention,
+    set_block_mapping,
     HPUPagedAttentionMetadata,
 )
 from text_generation_server.layers.attention.kv_cache import get_kv_scales
@@ -446,6 +447,10 @@ class MixtralModel(torch.nn.Module):
         seqlen: Seqlen,
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
     ) -> torch.Tensor:
+        if hpu_attention_meta is not None:
+            hpu_attention_meta = set_block_mapping(
+                hpu_attention_meta, input_ids.shape[0]
+            )
         hidden_states = self.embed_tokens(input_ids)
 
         # Get rotary cos and sin for this forward
@@ -505,7 +510,6 @@ class FlashMixtralForCausalLM(torch.nn.Module):
         lm_head_indices: Optional[torch.Tensor] = None,
         adapter_data: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-
         hidden_states = self.model(
             input_ids,
             position_ids,
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
index 29620826..6e1050b6 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
@@ -29,6 +29,7 @@ from typing import Optional, List, Tuple
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
+    set_block_mapping,
     Seqlen,
     HPUPagedAttentionMetadata,
 )
@@ -354,6 +355,10 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
         seqlen: Seqlen,
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
     ) -> torch.Tensor:
+        if hpu_attention_meta is not None:
+            hpu_attention_meta = set_block_mapping(
+                hpu_attention_meta, input_ids.shape[0]
+            )
         hidden_states = self.embed_in(input_ids)
 
         # Get rotary cos and sin for this forward
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py
index 12830991..78aaf0d5 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py
@@ -9,6 +9,7 @@ from typing import Optional, List, Tuple
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
+    set_block_mapping,
     Seqlen,
     HPUPagedAttentionMetadata,
 )
@@ -347,6 +348,10 @@ class FlashPhiModel(torch.nn.Module):
         seqlen: Seqlen,
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
     ) -> torch.Tensor:
+        if hpu_attention_meta is not None:
+            hpu_attention_meta = set_block_mapping(
+                hpu_attention_meta, input_ids.shape[0]
+            )
         hidden_states = self.embed_tokens(input_ids)
 
         # Get rotary cos and sin for this forward
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
index 7c7ac03e..ac31e53b 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
@@ -8,6 +8,7 @@ from typing import Optional, List, Tuple
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
+    set_block_mapping,
     Seqlen,
     HPUPagedAttentionMetadata,
 )
@@ -288,6 +289,10 @@ class Qwen2Model(torch.nn.Module):
         seqlen: Seqlen,
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
     ) -> torch.Tensor:
+        if hpu_attention_meta is not None:
+            hpu_attention_meta = set_block_mapping(
+                hpu_attention_meta, inputs_embeds.shape[0]
+            )
         hidden_states = inputs_embeds
 
         cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
@@ -359,7 +364,6 @@ class Qwen2ForCausalLM(torch.nn.Module):
         lm_head_indices: Optional[torch.Tensor] = None,
         adapter_data: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-
         inputs_embeds = self.embed_tokens(input_ids)
 
         hidden_states = self.model(
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_modeling.py
index 66a17877..8bd00c13 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_modeling.py
@@ -18,6 +18,7 @@ import habana_frameworks.torch as htorch
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
+    set_block_mapping,
     Seqlen,
     HPUPagedAttentionMetadata,
 )
@@ -266,7 +267,10 @@ class Qwen3Model(nn.Module):
         seqlen: Seqlen,
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
     ) -> torch.Tensor:
-
+        if hpu_attention_meta is not None:
+            hpu_attention_meta = set_block_mapping(
+                hpu_attention_meta, inputs_embeds.shape[0]
+            )
         hidden_states = inputs_embeds
 
         # create position embeddings to be shared across the decoder layers
@@ -334,7 +338,6 @@ class Qwen3ForCausalLM(nn.Module):
         lm_head_indices: Optional[torch.Tensor] = None,
         adapter_data: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-
         inputs_embeds = self.embed_tokens(input_ids)
         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
         hidden_states = self.model(
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
index 76a2cd01..06616f85 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
@@ -18,6 +18,7 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.layers.attention import (
     attention,
     paged_attention,
+    set_block_mapping,
     Seqlen,
     HPUPagedAttentionMetadata,
 )
@@ -628,6 +629,10 @@ class FlashRWModel(FlashRWPreTrainedModel):
         seqlen: Seqlen,
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
     ) -> torch.Tensor:
+        if hpu_attention_meta is not None:
+            hpu_attention_meta = set_block_mapping(
+                hpu_attention_meta, input_ids.shape[0]
+            )
         hidden_states = self.word_embeddings(input_ids)
 
         # Get rotary cos and sin for this forward
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
index c64b2ff7..b6a0d32a 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
@@ -8,6 +8,7 @@ from typing import Optional, List, Tuple
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
+    set_block_mapping,
     Seqlen,
     HPUPagedAttentionMetadata,
 )
@@ -437,6 +438,10 @@ class FlashSantacoderModel(nn.Module):
         seqlen: Seqlen,
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
     ) -> torch.Tensor:
+        if hpu_attention_meta is not None:
+            hpu_attention_meta = set_block_mapping(
+                hpu_attention_meta, input_ids.shape[0]
+            )
         hidden_states = self.wte(input_ids) + self.wpe(position_ids)
 
         if self.process_group.size() > 1:
diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py
index 94c60eb6..1a749595 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py
@@ -29,6 +29,7 @@ from typing import Optional, List, Tuple
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
+    set_block_mapping,
     Seqlen,
     HPUPagedAttentionMetadata,
 )
@@ -511,6 +512,10 @@ class Starcoder2Model(torch.nn.Module):
         adapter_data,
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
     ) -> torch.Tensor:
+        if hpu_attention_meta is not None:
+            hpu_attention_meta = set_block_mapping(
+                hpu_attention_meta, input_ids.shape[0]
+            )
         hidden_states = self.embed_tokens(input_ids)
 
         # Get rotary cos and sin for this forward
@@ -584,7 +589,6 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module):
         lm_head_indices: Optional[torch.Tensor] = None,
         adapter_data: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-
         hidden_states = self.model(
             input_ids,
             position_ids,
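Note (reviewer comment, not part of the patch): every flash_*_modeling forward above now opens with the same guard, so block_mapping and attn_bias are filled in lazily right before the decoder layers run; code paths that pass no HPU paged-attention metadata (hpu_attention_meta is None) are untouched. The batch size handed to set_block_mapping is simply dim 0 of whatever batch tensor the model starts from: input_ids for models that embed inside forward, inputs_embeds for models that receive embeddings. A hypothetical consolidation of that guard, for illustration only (the real code stays inlined per model):

    # Hypothetical helper condensing the guard repeated in the files above;
    # not part of the patch, just a sketch of the shared pattern.
    from typing import Optional

    import torch
    from text_generation_server.layers.attention import (
        HPUPagedAttentionMetadata,
        set_block_mapping,
    )

    def fill_block_mapping(
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
        batch_like: torch.Tensor,  # input_ids or inputs_embeds; dim 0 = padded batch
    ) -> Optional[HPUPagedAttentionMetadata]:
        if hpu_attention_meta is not None:
            # set_block_mapping returns an updated copy of the metadata namedtuple
            # with block_mapping and attn_bias built on block_usage's device/dtype.
            hpu_attention_meta = set_block_mapping(
                hpu_attention_meta, batch_like.shape[0]
            )
        return hpu_attention_meta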
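Note on the flash_causal_lm.py hunks below: prepare_for_decode no longer materializes block_mapping and attn_bias up front; it returns the metadata with those fields set to None and leaves the work to set_block_mapping inside each model's forward. The padded all_input_ids_tensor is now allocated at least as wide as the tensor being copied from (max(max_total_tokens, self.all_input_ids_tensor.shape[-1])), so the row-by-row copy just after it still fits even if max_total_tokens is smaller than the current tensor's width. Finally, next_token_logprobs and accepted_ids are no longer moved to the CPU and re-sliced; like next_input_ids they stay on device and are compacted in place with index_copy_. A tiny CPU illustration of that in-place compaction (made-up values):

    # In-place compaction with Tensor.index_copy_, as used below for
    # next_token_logprobs and accepted_ids; plain CPU tensors for illustration.
    import torch

    valid_indices = torch.tensor([0, 2])       # surviving request slots
    accepted_ids = torch.tensor([3, 1, 2, 1])  # padded to the old batch size
    index = torch.arange(0, len(valid_indices))

    # Move the kept rows to the front without leaving the device or reallocating.
    accepted_ids.index_copy_(0, index, accepted_ids[valid_indices])
    print(accepted_ids)  # tensor([3, 2, 2, 1]); the first two rows are the kept entries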
diff --git a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
index ccec1ba6..13a2a307 100644
--- a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
+++ b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
@@ -153,19 +153,14 @@ def prepare_for_decode(
     block_list_device = _async_h2d_tensor_copy(block_list)
     block_groups_device = _async_h2d_tensor_copy(block_groups)
     block_usage_device = _async_h2d_tensor_copy(block_usage)
-    block_mapping = torch.nn.functional.one_hot(
-        block_groups_device, num_classes=batch_size
-    )
-    mask = torch.arange(0, BLOCK_SIZE, device=device, dtype=torch.int32).unsqueeze(0)
-    mask = mask >= block_usage_device.unsqueeze(-1)
-    attn_bias = torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf)
+
     return trim_attn_metadata(
         HPUPagedAttentionMetadata(
             block_list=block_list_device,
             block_groups=block_groups_device,
             block_usage=block_usage_device,
-            block_mapping=block_mapping.to(dtype),
-            attn_bias=attn_bias,
+            block_mapping=None,
+            attn_bias=None,
         )
     )
 
@@ -1298,7 +1293,9 @@ class FlashCausalLMBatch(Batch):
             self.prefill_next_token_indices + input_ids_padded_length_tensor
         )
         all_input_ids_tensor = torch.zeros(
-            (max_padded_bs, max_total_tokens), dtype=torch.int64, device="hpu"
+            (max_padded_bs, max(max_total_tokens, self.all_input_ids_tensor.shape[-1])),
+            dtype=torch.int64,
+            device="hpu",
         )
         for i in range(len(self)):
             all_input_ids_tensor[i, : self.all_input_ids_tensor.shape[-1]] = (
@@ -2051,8 +2048,6 @@ class FlashCausalLM(Model):
         )
         if batch.valid_indices is not None:
             # TODO speculative decoding handling missing
-            next_token_logprobs = next_token_logprobs.cpu()
-            accepted_ids = accepted_ids.cpu()
             index = torch.arange(
                 0,
                 len(batch.valid_indices),
@@ -2068,8 +2063,13 @@
                 0, index, next_input_ids[batch.valid_indices]
             )
             next_input_ids = next_input_ids[:padded_total_bs]
-            next_token_logprobs = next_token_logprobs[batch.valid_indices]
-            accepted_ids = accepted_ids[batch.valid_indices]
+
+            next_token_logprobs.index_copy_(
+                0, index, next_token_logprobs[batch.valid_indices]
+            )
+            accepted_ids.index_copy_(
+                0, index, accepted_ids[batch.valid_indices]
+            )
             if speculative_ids is not None:
                 speculative_ids = speculative_ids[batch.valid_indices]
             batch.top_n_tokens_tensor = batch.top_n_tokens_tensor[