From 1508ee8de125d97a305807553537a9b5487e70d5 Mon Sep 17 00:00:00 2001 From: "Wang, Yi A" Date: Thu, 27 Mar 2025 22:51:21 -0700 Subject: [PATCH] remove block_tables and prefill_cache_indices which will lead to dynamic shape Signed-off-by: Wang, Yi A --- .../layers/attention/hpu.py | 2 - .../custom_modeling/flash_cohere_modeling.py | 27 +--- .../custom_modeling/flash_dbrx_modeling.py | 28 +--- .../flash_deepseek_v2_modeling.py | 27 +--- .../flash_deepseek_v3_modeling.py | 26 +--- .../custom_modeling/flash_gemma2_modeling.py | 24 +--- .../custom_modeling/flash_gemma_modeling.py | 25 +--- .../custom_modeling/flash_gpt2_modeling.py | 27 +--- .../custom_modeling/flash_llama_modeling.py | 30 +---- .../custom_modeling/flash_llava_next.py | 4 - .../custom_modeling/flash_mistral_modeling.py | 31 +---- .../custom_modeling/flash_mixtral_modeling.py | 30 +---- .../models/custom_modeling/flash_mllama.py | 6 - .../custom_modeling/flash_neox_modeling.py | 26 +--- .../flash_pali_gemma_modeling.py | 4 - .../custom_modeling/flash_phi_modeling.py | 24 +--- .../custom_modeling/flash_qwen2_modeling.py | 31 +---- .../custom_modeling/flash_rw_modeling.py | 44 +----- .../flash_santacoder_modeling.py | 25 +--- .../flash_starcoder2_modeling.py | 30 +---- .../models/custom_modeling/idefics2.py | 4 - .../models/custom_modeling/idefics3.py | 4 - .../models/custom_modeling/qwen2_5_vl.py | 4 - .../models/custom_modeling/qwen2_vl.py | 4 - .../models/flash_causal_lm.py | 127 ++++++++---------- .../models/flash_vlm_causal_lm.py | 2 - .../models/mllama_causal_lm.py | 2 - 27 files changed, 88 insertions(+), 530 deletions(-) diff --git a/backends/gaudi/server/text_generation_server/layers/attention/hpu.py b/backends/gaudi/server/text_generation_server/layers/attention/hpu.py index 56143541..526dbcec 100644 --- a/backends/gaudi/server/text_generation_server/layers/attention/hpu.py +++ b/backends/gaudi/server/text_generation_server/layers/attention/hpu.py @@ -26,7 +26,6 @@ def attention( kv_cache: KVCache, kv_scales: KVScales, seqlen: Seqlen, - block_tables: torch.Tensor, softmax_scale: float, window_size_left: int = -1, causal: bool = True, @@ -61,7 +60,6 @@ def paged_attention( kv_cache: KVCache, kv_head_mapping: torch.Tensor, softmax_scale: float, - block_tables: torch.Tensor, seqlen: Seqlen, *, kv_scales: KVScales, diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py index 77dec80d..3bcc689d 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py @@ -219,10 +219,8 @@ class FlashCohereAttention(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ): qkv = self.query_key_value(hidden_states) @@ -247,16 +245,9 @@ class FlashCohereAttention(torch.nn.Module): self.rotary_emb(query, key, cos, sin) - if prefill_cache_indices is not None: - key_to_cache = key[prefill_cache_indices] - value_to_cache = value[prefill_cache_indices] - else: - key_to_cache = key - value_to_cache = value - kv_cache.store( - key=key_to_cache, - value=value_to_cache, + key=key, + value=value, slots=slots, kv_scales=self.kv_scales, ) @@ -271,7 +262,6 @@ class FlashCohereAttention(torch.nn.Module): kv_cache=kv_cache, kv_scales=self.kv_scales, seqlen=seqlen, - block_tables=block_tables, 
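# A minimal sketch of the pattern this patch removes from every attention module
# in the files above and below. Before the change, prefill key/value tensors were
# gathered through prefill_cache_indices before being written to the KV cache,
# which makes the stored shape depend on the batch content (the "dynamic shape"
# the commit message refers to). After the change the full tensors are stored
# directly. The standalone helper name is illustrative; kv_cache, slots and
# kv_scales are the objects already used by the surrounding code.

def store_prefill_kv(kv_cache, key, value, slots, kv_scales, prefill_cache_indices=None):
    if prefill_cache_indices is not None:
        # Old path: gather only the positions that must be cached (dynamic shape).
        key = key[prefill_cache_indices]
        value = value[prefill_cache_indices]
    # New path (this patch): prefill_cache_indices no longer exists, so key/value
    # are written as-is and the shape stays static.
    kv_cache.store(key=key, value=value, slots=slots, kv_scales=kv_scales)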
softmax_scale=self.softmax_scale, ) # Decode @@ -281,7 +271,6 @@ class FlashCohereAttention(torch.nn.Module): kv_cache, self.kv_head_mapping, self.softmax_scale, - block_tables, seqlen, kv_scales=self.kv_scales, hpu_attention_meta=hpu_attention_meta, @@ -356,10 +345,8 @@ class FlashCohereLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ): normed_hidden_states, res = self.input_layernorm(hidden_states, residual) @@ -371,10 +358,8 @@ class FlashCohereLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ) @@ -424,10 +409,8 @@ class FlashCohereModel(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: torch.Tensor, - prefill_cache_indices: Optional[torch.Tensor], hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) @@ -446,10 +429,8 @@ class FlashCohereModel(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache[i], - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ) @@ -488,10 +469,8 @@ class FlashCohereForCausalLM(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - prefill_cache_indices: Optional[torch.Tensor], hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, @@ -501,10 +480,8 @@ class FlashCohereForCausalLM(torch.nn.Module): position_ids, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ) if lm_head_indices is not None: diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py index b335a81f..15c243c9 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py @@ -308,10 +308,8 @@ class DbrxAttention(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ): qkv = self.query_key_value(hidden_states) @@ -329,14 +327,10 @@ class DbrxAttention(torch.nn.Module): kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size) self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) - if prefill_cache_indices is not None: - kv_to_cache = kv[prefill_cache_indices] - else: - kv_to_cache = kv kv_cache.store( - key=kv_to_cache[:, 0], - value=kv_to_cache[:, 1], + key=kv[:, 0], + value=kv[:, 1], slots=slots, kv_scales=self.kv_scales, ) @@ -351,7 +345,6 @@ class DbrxAttention(torch.nn.Module): kv_cache=kv_cache, kv_scales=self.kv_scales, seqlen=seqlen, - block_tables=block_tables, softmax_scale=self.softmax_scale, ) # Decode @@ -361,7 +354,6 @@ class DbrxAttention(torch.nn.Module): kv_cache, self.kv_head_mapping, self.softmax_scale, - block_tables, seqlen, kv_scales=self.kv_scales, hpu_attention_meta=hpu_attention_meta, @@ -398,10 +390,8 @@ class DbrxNormAttentionNorm(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - 
prefill_cache_indices, hpu_attention_meta, ): normed_hidden_states, res = self.norm_1(hidden_states, residual) @@ -413,10 +403,8 @@ class DbrxNormAttentionNorm(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ) @@ -630,10 +618,8 @@ class DbrxLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ): # Self Attention @@ -644,10 +630,8 @@ class DbrxLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ) @@ -689,10 +673,8 @@ class DbrxModel(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - prefill_cache_indices: Optional[torch.Tensor], hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) @@ -710,10 +692,8 @@ class DbrxModel(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache[i], - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ) @@ -744,10 +724,8 @@ class FlashDbrxForCausalLM(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - prefill_cache_indices: Optional[torch.Tensor], hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, @@ -757,10 +735,8 @@ class FlashDbrxForCausalLM(torch.nn.Module): position_ids, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ) if lm_head_indices is not None: diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py index 3298a30a..9d61c694 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py @@ -256,10 +256,8 @@ class DeepseekV2Attention(torch.nn.Module): sin: torch.Tensor, cu_seqlen_prefill: torch.Tensor, kv_cache: KVCache, - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - prefill_cache_indices: Optional[torch.Tensor], hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ): if self.q_lora_rank is None: @@ -316,15 +314,10 @@ class DeepseekV2Attention(torch.nn.Module): value = torch.nn.functional.pad( value, (0, self.head_pad_size - self.value_head_size), value=0 ) - if prefill_cache_indices is not None: - key_to_cache = key[prefill_cache_indices] - value_to_cache = value[prefill_cache_indices] - else: - key_to_cache = key - value_to_cache = value + kv_cache.store( - key=key_to_cache, - value=value_to_cache, + key=key, + value=value, slots=slots, kv_scales=self.kv_scales, ) @@ -339,7 +332,6 @@ class DeepseekV2Attention(torch.nn.Module): kv_cache=kv_cache, kv_scales=self.kv_scales, seqlen=seqlen, - block_tables=block_tables, softmax_scale=self.softmax_scale, ) # Decode @@ -349,7 +341,6 @@ class DeepseekV2Attention(torch.nn.Module): kv_cache, self.kv_head_mapping, self.softmax_scale, - block_tables, seqlen, kv_scales=self.kv_scales, hpu_attention_meta=hpu_attention_meta, 
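# A sketch of how the per-layer attention dispatch looks after this patch,
# paraphrased from the call sites in the hunks above: the prefill branch keys off
# cu_seqlen_prefill, and the decode branch relies on hpu_attention_meta
# (HPUPagedAttentionMetadata, built by prepare_for_decode in flash_causal_lm.py,
# diffed at the end of this patch) rather than an explicit block_tables tensor.
# attention/paged_attention are the kernels from layers/attention/hpu.py; the
# wrapper function and exact keyword usage are illustrative, and per-model options
# such as window_size_left or softcap are omitted.

def run_attention(self, query, key, value, kv_cache, cu_seqlen_prefill, seqlen, hpu_attention_meta):
    if cu_seqlen_prefill is not None:
        # Prefill
        return attention(
            query=query,
            key=key,
            value=value,
            kv_cache=kv_cache,
            kv_scales=self.kv_scales,
            seqlen=seqlen,
            softmax_scale=self.softmax_scale,
        )
    # Decode
    return paged_attention(
        query,
        kv_cache,
        self.kv_head_mapping,
        self.softmax_scale,
        seqlen,
        kv_scales=self.kv_scales,
        hpu_attention_meta=hpu_attention_meta,
    )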
@@ -512,10 +503,8 @@ class DeepseekV2Layer(nn.Module): sin: torch.Tensor, cu_seqlen_prefill: torch.Tensor, kv_cache, - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - prefill_cache_indices: Optional[torch.Tensor], hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ): normed_hidden_states, residual = self.input_layernorm(hidden_states, residual) @@ -527,10 +516,8 @@ class DeepseekV2Layer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ) @@ -577,10 +564,8 @@ class DeepseekV2Model(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - prefill_cache_indices: Optional[torch.Tensor], hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) @@ -598,10 +583,8 @@ class DeepseekV2Model(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache[i], - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ) @@ -629,10 +612,8 @@ class FlashDeepseekV2ForCausalLM(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - prefill_cache_indices: Optional[torch.Tensor], hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, @@ -642,10 +623,8 @@ class FlashDeepseekV2ForCausalLM(torch.nn.Module): position_ids, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ) if lm_head_indices is not None: diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v3_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v3_modeling.py index 736e0c9a..1a7ce5cf 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v3_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v3_modeling.py @@ -256,10 +256,8 @@ class DeepseekV3Attention(torch.nn.Module): sin: torch.Tensor, cu_seqlen_prefill: torch.Tensor, kv_cache: KVCache, - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - prefill_cache_indices: Optional[torch.Tensor], hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ): if self.q_lora_rank is None: @@ -317,15 +315,9 @@ class DeepseekV3Attention(torch.nn.Module): value, (0, self.head_pad_size - self.value_head_size), value=0 ) - if prefill_cache_indices is not None: - key_to_cache = key[prefill_cache_indices] - value_to_cache = value[prefill_cache_indices] - else: - key_to_cache = key - value_to_cache = value kv_cache.store( - key=key_to_cache, - value=value_to_cache, + key=key, + value=value, slots=slots, kv_scales=self.kv_scales, ) @@ -340,7 +332,6 @@ class DeepseekV3Attention(torch.nn.Module): kv_cache=kv_cache, kv_scales=self.kv_scales, seqlen=seqlen, - block_tables=block_tables, softmax_scale=self.softmax_scale, ) # Decode @@ -350,7 +341,6 @@ class DeepseekV3Attention(torch.nn.Module): kv_cache, self.kv_head_mapping, self.softmax_scale, - block_tables, seqlen, kv_scales=self.kv_scales, hpu_attention_meta=hpu_attention_meta, @@ -522,10 +512,8 @@ class DeepseekV3Layer(nn.Module): sin: torch.Tensor, cu_seqlen_prefill: 
torch.Tensor, kv_cache, - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - prefill_cache_indices: Optional[torch.Tensor], hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ): normed_hidden_states, residual = self.input_layernorm(hidden_states, residual) @@ -537,10 +525,8 @@ class DeepseekV3Layer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ) @@ -587,10 +573,8 @@ class DeepseekV3Model(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - prefill_cache_indices: Optional[torch.Tensor], hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) @@ -608,10 +592,8 @@ class DeepseekV3Model(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache[i], - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ) @@ -639,10 +621,8 @@ class FlashDeepseekV3ForCausalLM(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - prefill_cache_indices: Optional[torch.Tensor], hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, @@ -652,10 +632,8 @@ class FlashDeepseekV3ForCausalLM(torch.nn.Module): position_ids, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ) if lm_head_indices is not None: diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py index 632e8017..79f21b0f 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py @@ -235,11 +235,9 @@ class FlashGemma2Attention(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, adapter_data, - prefill_cache_indices, hpu_attention_meta, ): qkv = self.query_key_value(hidden_states, adapter_data) @@ -254,14 +252,10 @@ class FlashGemma2Attention(torch.nn.Module): kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size) self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) - if prefill_cache_indices is not None: - kv_to_cache = kv[prefill_cache_indices] - else: - kv_to_cache = kv kv_cache.store( - key=kv_to_cache[:, 0], - value=kv_to_cache[:, 1], + key=kv[:, 0], + value=kv[:, 1], slots=slots, kv_scales=self.kv_scales, ) @@ -276,7 +270,6 @@ class FlashGemma2Attention(torch.nn.Module): kv_cache=kv_cache, kv_scales=self.kv_scales, seqlen=seqlen, - block_tables=block_tables, softmax_scale=self.softmax_scale, window_size_left=self.window_size, softcap=self.softcap, @@ -288,7 +281,6 @@ class FlashGemma2Attention(torch.nn.Module): kv_cache, self.kv_head_mapping, self.softmax_scale, - block_tables, seqlen, softcap=self.softcap, kv_scales=self.kv_scales, @@ -402,11 +394,9 @@ class FlashGemma2Layer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, adapter_data, - prefill_cache_indices, hpu_attention_meta, ): normed_hidden_states, res = self.input_layernorm(hidden_states, residual) @@ 
-418,11 +408,9 @@ class FlashGemma2Layer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, adapter_data, - prefill_cache_indices, hpu_attention_meta, ) @@ -472,11 +460,9 @@ class FlashGemma2Model(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, adapter_data: Optional[torch.Tensor], - prefill_cache_indices: Optional[torch.Tensor], hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: hidden_states = inputs_embeds @@ -494,11 +480,9 @@ class FlashGemma2Model(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache[i], - block_tables, slots, seqlen, adapter_data, - prefill_cache_indices, hpu_attention_meta, ) @@ -543,10 +527,8 @@ class FlashGemma2ForCausalLM(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - prefill_cache_indices: Optional[torch.Tensor], hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, @@ -557,11 +539,9 @@ class FlashGemma2ForCausalLM(torch.nn.Module): position_ids, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, adapter_data, - prefill_cache_indices, hpu_attention_meta, ) if lm_head_indices is not None: diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py index c3e5727b..609f03ac 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py @@ -207,10 +207,8 @@ class FlashGemmaAttention(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ): qkv = self.query_key_value(hidden_states) @@ -226,14 +224,9 @@ class FlashGemmaAttention(torch.nn.Module): self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) - if prefill_cache_indices is not None: - kv_to_cache = kv[prefill_cache_indices] - else: - kv_to_cache = kv - kv_cache.store( - key=kv_to_cache[:, 0], - value=kv_to_cache[:, 1], + key=kv[:, 0], + value=kv[:, 1], slots=slots, kv_scales=self.kv_scales, ) @@ -248,7 +241,6 @@ class FlashGemmaAttention(torch.nn.Module): kv_cache=kv_cache, kv_scales=self.kv_scales, seqlen=seqlen, - block_tables=block_tables, softmax_scale=self.softmax_scale, causal=self.causal, ) @@ -259,7 +251,6 @@ class FlashGemmaAttention(torch.nn.Module): kv_cache, self.kv_head_mapping, self.softmax_scale, - block_tables, seqlen, kv_scales=self.kv_scales, hpu_attention_meta=hpu_attention_meta, @@ -331,10 +322,8 @@ class FlashGemmaLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ): normed_hidden_states, res = self.input_layernorm(hidden_states, residual) @@ -346,10 +335,8 @@ class FlashGemmaLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ) @@ -395,11 +382,9 @@ class FlashGemmaModel(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - 
block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, adapter_data: Optional[torch.Tensor], - prefill_cache_indices: Optional[torch.Tensor], hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: hidden_states = inputs_embeds @@ -417,10 +402,8 @@ class FlashGemmaModel(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache[i], - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ) @@ -463,10 +446,8 @@ class FlashGemmaForCausalLM(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - prefill_cache_indices: Optional[torch.Tensor], hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, @@ -477,11 +458,9 @@ class FlashGemmaForCausalLM(torch.nn.Module): position_ids, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, adapter_data, - prefill_cache_indices, hpu_attention_meta, ) if lm_head_indices is not None: diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py index a7a85d3a..10024a6d 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py @@ -209,10 +209,8 @@ class FlashGPT2Attention(torch.nn.Module): hidden_states, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ): query, key, value = self.query_key_value(hidden_states).split( @@ -222,16 +220,9 @@ class FlashGPT2Attention(torch.nn.Module): key = key.view(-1, self.num_heads, self.head_size) value = value.view(-1, self.num_heads, self.head_size) - if prefill_cache_indices is not None: - key_to_cache = key[prefill_cache_indices] - value_to_cache = value[prefill_cache_indices] - else: - key_to_cache = key - value_to_cache = value - kv_cache.store( - key=key_to_cache, - value=value_to_cache, + key=key, + value=value, slots=slots, kv_scales=self.kv_scales, ) @@ -246,7 +237,6 @@ class FlashGPT2Attention(torch.nn.Module): kv_cache=kv_cache, kv_scales=self.kv_scales, seqlen=seqlen, - block_tables=block_tables, softmax_scale=self.softmax_scale, ) # Decode @@ -256,7 +246,6 @@ class FlashGPT2Attention(torch.nn.Module): kv_cache, self.kv_head_mapping, self.softmax_scale, - block_tables, seqlen, kv_scales=self.kv_scales, hpu_attention_meta=hpu_attention_meta, @@ -325,10 +314,8 @@ class FlashGPT2Layer(nn.Module): residual, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ): residual = hidden_states @@ -339,10 +326,8 @@ class FlashGPT2Layer(nn.Module): hidden_states, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ) @@ -393,10 +378,8 @@ class FlashGPT2Model(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - prefill_cache_indices: Optional[torch.Tensor], hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: hidden_states = inputs_embeds @@ -408,10 +391,8 @@ class FlashGPT2Model(torch.nn.Module): residual, cu_seqlen_prefill, 
kv_cache[i], - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ) @@ -446,11 +427,9 @@ class FlashGPT2ForCausalLM(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, hpu_attention_meta: Optional[HPUPagedAttentionMetadata], - prefill_cache_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: @@ -462,10 +441,8 @@ class FlashGPT2ForCausalLM(torch.nn.Module): position_ids, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices=prefill_cache_indices, hpu_attention_meta=hpu_attention_meta, ) if lm_head_indices is not None: diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 7deb6cbf..81af5560 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -201,11 +201,9 @@ class FlashLlamaAttention(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache: KVCache, - block_tables, slots, seqlen, adapter_data, - prefill_cache_indices: Optional[torch.Tensor], hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ): qkv = self.query_key_value(hidden_states, adapter_data) @@ -221,14 +219,9 @@ class FlashLlamaAttention(torch.nn.Module): self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) - if prefill_cache_indices is not None: - kv_to_cache = kv[prefill_cache_indices] - else: - kv_to_cache = kv - kv_cache.store( - key=kv_to_cache[:, 0], - value=kv_to_cache[:, 1], + key=kv[:, 0], + value=kv[:, 1], slots=slots, kv_scales=self.kv_scales, ) @@ -243,7 +236,6 @@ class FlashLlamaAttention(torch.nn.Module): kv_scales=self.kv_scales, kv_cache=kv_cache, seqlen=seqlen, - block_tables=block_tables, softmax_scale=self.softmax_scale, ) # Decode @@ -253,7 +245,6 @@ class FlashLlamaAttention(torch.nn.Module): kv_cache, self.kv_head_mapping, self.softmax_scale, - block_tables, seqlen, kv_scales=self.kv_scales, hpu_attention_meta=hpu_attention_meta, @@ -441,12 +432,10 @@ class FlashLlamaLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, adapter_data, cross_attention_states, - prefill_cache_indices: Optional[torch.Tensor], hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ): normed_hidden_states, res = self.input_layernorm(hidden_states, residual) @@ -458,11 +447,9 @@ class FlashLlamaLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, adapter_data, - prefill_cache_indices, hpu_attention_meta=hpu_attention_meta, ) if self.residual_multiplier is not None: @@ -554,10 +541,8 @@ class FlashLlamaModel(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - prefill_cache_indices: Optional[torch.Tensor], adapter_data, hpu_attention_meta: Optional[HPUPagedAttentionMetadata], cross_attention_states=None, @@ -577,12 +562,10 @@ class FlashLlamaModel(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache[i], - block_tables, slots, seqlen, adapter_data, cross_attention_states, - 
prefill_cache_indices, hpu_attention_meta=hpu_attention_meta, ) @@ -643,30 +626,21 @@ class FlashLlamaForCausalLM(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, hpu_attention_meta: Optional[HPUPagedAttentionMetadata], - prefill_cache_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, cross_attention_states=None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - if prefill_cache_indices is not None and slots.size( - 0 - ) != prefill_cache_indices.size(0): - # Slots also need to be sliced as it has the same size as the whole kv tensor - slots = slots[prefill_cache_indices] inputs_embeds = self.embed_tokens(input_ids) hidden_states = self.model( inputs_embeds, position_ids, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices=prefill_cache_indices, adapter_data=adapter_data, cross_attention_states=cross_attention_states, hpu_attention_meta=hpu_attention_meta, diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llava_next.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llava_next.py index 3bdfdd83..62e8470c 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llava_next.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llava_next.py @@ -169,11 +169,9 @@ class FlashLlavaNextForConditionalGeneration(nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, hpu_attention_meta: Optional[HPUPagedAttentionMetadata], - prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor] = None, pixel_values: torch.FloatTensor = None, # Unused for this model @@ -276,11 +274,9 @@ class FlashLlavaNextForConditionalGeneration(nn.Module): position_ids=position_ids, cu_seqlen_prefill=cu_seqlen_prefill, kv_cache=kv_cache, - block_tables=block_tables, slots=slots, seqlen=seqlen, hpu_attention_meta=hpu_attention_meta, - prefill_cache_indices=None, adapter_data=adapter_data, ) if lm_head_indices is not None: diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py index 38eba082..d23d4f67 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py @@ -178,10 +178,8 @@ class MistralAttention(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, adapter_data, hpu_attention_meta, ): @@ -198,14 +196,9 @@ class MistralAttention(torch.nn.Module): self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) - if prefill_cache_indices is not None: - kv_to_cache = kv[prefill_cache_indices] - else: - kv_to_cache = kv - kv_cache.store( - key=kv_to_cache[:, 0], - value=kv_to_cache[:, 1], + key=kv[:, 0], + value=kv[:, 1], slots=slots, kv_scales=self.kv_scales, ) @@ -220,7 +213,6 @@ class MistralAttention(torch.nn.Module): kv_cache=kv_cache, kv_scales=self.kv_scales, seqlen=seqlen, - block_tables=block_tables, 
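# The top-level forwards drop the matching slot-slicing guard as well (visible in
# FlashLlamaForCausalLM above and repeated for Mistral, Mixtral, Qwen2 and
# Starcoder2 below). A small sketch of the old versus new behaviour; the helper
# name is illustrative, and the statement about flash_causal_lm.py is an
# assumption, since that part of the diff is truncated in this excerpt.

def select_slots(slots, prefill_cache_indices=None):
    # Old behaviour (the deleted branch): slots had the same size as the whole kv
    # tensor, so it was gathered down to the positions actually cached, which is
    # one more data-dependent shape.
    if prefill_cache_indices is not None and slots.size(0) != prefill_cache_indices.size(0):
        slots = slots[prefill_cache_indices]
    # New behaviour: the branch is gone and slots are used as-is, assuming the
    # batch-preparation code in flash_causal_lm.py now supplies one slot per token
    # that gets written to the cache.
    return slots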
softmax_scale=self.softmax_scale, window_size_left=self.max_past, ) @@ -231,7 +223,6 @@ class MistralAttention(torch.nn.Module): kv_cache, self.kv_head_mapping, self.softmax_scale, - block_tables, seqlen, kv_scales=self.kv_scales, hpu_attention_meta=hpu_attention_meta, @@ -335,10 +326,8 @@ class MistralLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, adapter_data, hpu_attention_meta, ): @@ -351,10 +340,8 @@ class MistralLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, adapter_data, hpu_attention_meta, ) @@ -403,10 +390,8 @@ class MistralModel(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - prefill_cache_indices: Optional[torch.Tensor], hpu_attention_meta: Optional[HPUPagedAttentionMetadata], adapter_data: Optional[torch.Tensor] = None, ): @@ -424,10 +409,8 @@ class MistralModel(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache[i], - block_tables, slots, seqlen, - prefill_cache_indices, adapter_data, hpu_attention_meta, ) @@ -475,30 +458,20 @@ class FlashMistralForCausalLM(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - prefill_cache_indices: Optional[torch.Tensor], hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, ) -> torch.Tensor: - if prefill_cache_indices is not None and slots.size( - 0 - ) != prefill_cache_indices.size(0): - # Slots also need to be sliced as it has the same size as the whole kv tensor - slots = slots[prefill_cache_indices] - inputs_embeds = self.embed_tokens(input_ids) hidden_states = self.model( inputs_embeds, position_ids, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, adapter_data, ) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py index fbcb0970..1ef6be48 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py @@ -235,10 +235,8 @@ class MixtralAttention(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ): qkv = self.query_key_value(hidden_states) @@ -254,14 +252,9 @@ class MixtralAttention(torch.nn.Module): self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) - if prefill_cache_indices is not None: - kv_to_cache = kv[prefill_cache_indices] - else: - kv_to_cache = kv - kv_cache.store( - key=kv_to_cache[:, 0], - value=kv_to_cache[:, 1], + key=kv[:, 0], + value=kv[:, 1], slots=slots, kv_scales=self.kv_scales, ) @@ -276,7 +269,6 @@ class MixtralAttention(torch.nn.Module): kv_cache=kv_cache, kv_scales=self.kv_scales, seqlen=seqlen, - block_tables=block_tables, softmax_scale=self.softmax_scale, window_size_left=self.max_past, ) @@ -287,7 +279,6 @@ class MixtralAttention(torch.nn.Module): kv_cache, self.kv_head_mapping, self.softmax_scale, - block_tables, seqlen, kv_scales=self.kv_scales, 
hpu_attention_meta=hpu_attention_meta, @@ -384,10 +375,8 @@ class MixtralLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ): normed_hidden_states, res = self.input_layernorm(hidden_states, residual) @@ -399,10 +388,8 @@ class MixtralLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ) @@ -454,10 +441,8 @@ class MixtralModel(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - prefill_cache_indices: Optional[torch.Tensor], hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) @@ -475,10 +460,8 @@ class MixtralModel(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache[i], - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ) @@ -510,29 +493,20 @@ class FlashMixtralForCausalLM(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - prefill_cache_indices: Optional[torch.Tensor], hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, ) -> torch.Tensor: - if prefill_cache_indices is not None and slots.size( - 0 - ) != prefill_cache_indices.size(0): - # Slots also need to be sliced as it has the same size as the whole kv tensor - slots = slots[prefill_cache_indices] hidden_states = self.model( input_ids, position_ids, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ) if lm_head_indices is not None: diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mllama.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mllama.py index b26adad7..216642e0 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mllama.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mllama.py @@ -801,12 +801,10 @@ class FlashLlamaCrossLayer(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, adapter_data, cross_attention_states, # [ IB, ...] - prefill_cache_indices, hpu_attention_meta, ) -> Tuple[torch.Tensor, torch.Tensor]: if cross_attention_states is None: @@ -911,11 +909,9 @@ class FlashMllamaForConditionalGeneration(nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, hpu_attention_meta: Optional[HPUPagedAttentionMetadata], - prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor], adapter_data: Optional[torch.Tensor] = None, # XXX: Putting these as optional so that the cuda warmup calls can go through. 
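# The net effect on every public forward() signature in this patch, for the
# causal-LM heads as well as the multimodal wrappers (mllama here; llava-next,
# idefics2/3, pali-gemma and qwen2-vl elsewhere): block_tables and
# prefill_cache_indices simply disappear, and nothing replaces them at this level.
# Parameter names and types are copied from the post-patch signatures above; the
# wrapper class is illustrative, and Seqlen / HPUPagedAttentionMetadata are the
# types these files already import (their import path is not shown in the hunks,
# hence the string annotations).

import torch
from typing import List, Optional, Tuple


class ForwardAfterPatch:
    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        slots: torch.Tensor,
        seqlen: "Seqlen",
        hpu_attention_meta: "Optional[HPUPagedAttentionMetadata]",
        lm_head_indices: Optional[torch.Tensor] = None,
        adapter_data: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # block_tables used to follow kv_cache, and prefill_cache_indices used to
        # follow seqlen or hpu_attention_meta; both are gone in every model file.
        ...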
@@ -979,11 +975,9 @@ class FlashMllamaForConditionalGeneration(nn.Module): position_ids=position_ids, cu_seqlen_prefill=cu_seqlen_prefill, kv_cache=kv_cache, - block_tables=block_tables, slots=slots, seqlen=seqlen, hpu_attention_meta=hpu_attention_meta, - prefill_cache_indices=prefill_cache_indices, lm_head_indices=lm_head_indices, adapter_data=adapter_data, cross_attention_states=cross_attention_states, diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index d1904c03..33f63333 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -147,10 +147,8 @@ class FlashNeoxAttention(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ): qkv = self.query_key_value(hidden_states) @@ -166,14 +164,10 @@ class FlashNeoxAttention(torch.nn.Module): self.rotary_emb(query_rot, key_rot, cos, sin) qkv[:, 0] = torch.cat((query_rot, query_pass), dim=-1) qkv[:, 1] = torch.cat((key_rot, key_pass), dim=-1) - if prefill_cache_indices is not None: - qkv_to_cache = qkv[prefill_cache_indices] - else: - qkv_to_cache = qkv kv_cache.store( - key=qkv_to_cache[:, 1], - value=qkv_to_cache[:, 2], + key=qkv[:, 1], + value=qkv[:, 2], slots=slots, kv_scales=self.kv_scales, ) @@ -188,7 +182,6 @@ class FlashNeoxAttention(torch.nn.Module): kv_cache=kv_cache, kv_scales=self.kv_scales, seqlen=seqlen, - block_tables=block_tables, softmax_scale=self.softmax_scale, ) # Decode @@ -198,7 +191,6 @@ class FlashNeoxAttention(torch.nn.Module): kv_cache, self.kv_head_mapping, self.softmax_scale, - block_tables, seqlen, kv_scales=self.kv_scales, hpu_attention_meta=hpu_attention_meta, @@ -268,10 +260,8 @@ class FlashNeoXLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ): if self.use_parallel_residual: @@ -283,10 +273,8 @@ class FlashNeoXLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ) @@ -308,10 +296,8 @@ class FlashNeoXLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ) @@ -363,10 +349,8 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - prefill_cache_indices: Optional[torch.Tensor], hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: hidden_states = self.embed_in(input_ids) @@ -384,10 +368,8 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel): sin, cu_seqlen_prefill, kv_cache[i], - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ) @@ -417,11 +399,9 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, hpu_attention_meta: Optional[HPUPagedAttentionMetadata], - prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, ) -> 
torch.Tensor: @@ -430,10 +410,8 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel): position_ids, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ) if lm_head_indices is not None: diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py index 2b67501d..4d31d5dd 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py @@ -69,11 +69,9 @@ class PaliGemmaForConditionalGeneration(nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, hpu_attention_meta: Optional[HPUPagedAttentionMetadata], - prefill_cache_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None, pixel_values: torch.FloatTensor = None, # Unused here @@ -106,11 +104,9 @@ class PaliGemmaForConditionalGeneration(nn.Module): position_ids=position_ids, cu_seqlen_prefill=cu_seqlen_prefill, kv_cache=kv_cache, - block_tables=block_tables, slots=slots, seqlen=seqlen, hpu_attention_meta=hpu_attention_meta, - prefill_cache_indices=None, adapter_data=adapter_data, ) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py index cf7c9a79..0c777912 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py @@ -160,10 +160,8 @@ class FlashPhiAttention(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ): # Compute query, key, value and split @@ -190,13 +188,9 @@ class FlashPhiAttention(torch.nn.Module): ) # Reshape key and value and cache - if prefill_cache_indices is not None: - kv_to_cache = kv[prefill_cache_indices] - else: - kv_to_cache = kv kv_cache.store( - key=kv_to_cache[:, 0], - value=kv_to_cache[:, 1], + key=kv[:, 0], + value=kv[:, 1], slots=slots, kv_scales=self.kv_scales, ) @@ -210,7 +204,6 @@ class FlashPhiAttention(torch.nn.Module): kv_scales=self.kv_scales, kv_cache=kv_cache, seqlen=seqlen, - block_tables=block_tables, softmax_scale=self.softmax_scale, ) # Decode @@ -220,7 +213,6 @@ class FlashPhiAttention(torch.nn.Module): kv_cache, self.kv_head_mapping, self.softmax_scale, - block_tables, seqlen, kv_scales=self.kv_scales, hpu_attention_meta=hpu_attention_meta, @@ -287,10 +279,8 @@ class FlashPhiLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ): hidden_states, res = self.input_layernorm(hidden_states, residual) @@ -301,10 +291,8 @@ class FlashPhiLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ) @@ -354,10 +342,8 @@ class FlashPhiModel(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - prefill_cache_indices: 
Optional[torch.Tensor], hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) @@ -375,10 +361,8 @@ class FlashPhiModel(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache[i], - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ) @@ -409,10 +393,8 @@ class FlashPhiForCausalLM(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - prefill_cache_indices: Optional[torch.Tensor], hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, @@ -422,10 +404,8 @@ class FlashPhiForCausalLM(torch.nn.Module): position_ids, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ) if lm_head_indices is not None: diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py index 480a17d1..af4b404d 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py @@ -106,10 +106,8 @@ class Qwen2Attention(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ): qkv = self.query_key_value(hidden_states) @@ -125,14 +123,9 @@ class Qwen2Attention(torch.nn.Module): self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) - if prefill_cache_indices is not None: - kv_to_cache = kv[prefill_cache_indices] - else: - kv_to_cache = kv - kv_cache.store( - key=kv_to_cache[:, 0], - value=kv_to_cache[:, 1], + key=kv[:, 0], + value=kv[:, 1], slots=slots, kv_scales=self.kv_scales, ) @@ -147,7 +140,6 @@ class Qwen2Attention(torch.nn.Module): kv_cache=kv_cache, kv_scales=self.kv_scales, seqlen=seqlen, - block_tables=block_tables, softmax_scale=self.softmax_scale, window_size_left=self.max_past, ) @@ -158,7 +150,6 @@ class Qwen2Attention(torch.nn.Module): kv_cache, self.kv_head_mapping, self.softmax_scale, - block_tables, seqlen, kv_scales=self.kv_scales, hpu_attention_meta=hpu_attention_meta, @@ -230,10 +221,8 @@ class Qwen2Layer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ): normed_hidden_states, residual = self.input_layernorm(hidden_states) @@ -245,10 +234,8 @@ class Qwen2Layer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ) hidden_states = attn_output + residual @@ -296,10 +283,8 @@ class Qwen2Model(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - prefill_cache_indices: Optional[torch.Tensor], hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: hidden_states = inputs_embeds @@ -317,10 +302,8 @@ class Qwen2Model(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache[i], - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ) @@ -364,21 +347,13 @@ class Qwen2ForCausalLM(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: 
Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - prefill_cache_indices: Optional[torch.Tensor], hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, ) -> torch.Tensor: - if prefill_cache_indices is not None and prefill_cache_indices.size( - 0 - ) != slots.size(0): - # Slots also need to be sliced as it has the same size as the whole kv tensor - slots = slots[prefill_cache_indices] - inputs_embeds = self.embed_tokens(input_ids) hidden_states = self.model( @@ -386,10 +361,8 @@ class Qwen2ForCausalLM(torch.nn.Module): position_ids, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ) if lm_head_indices is not None: diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index e7c4b2b6..141e13a6 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -182,10 +182,8 @@ class FlashRWAttention(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ): qkv = self.query_key_value(hidden_states) @@ -203,14 +201,9 @@ class FlashRWAttention(torch.nn.Module): # Inplace rotary self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) - if prefill_cache_indices is not None: - kv_to_cache = kv[prefill_cache_indices] - else: - kv_to_cache = kv - kv_cache.store( - key=kv_to_cache[:, 0], - value=kv_to_cache[:, 1], + key=kv[:, 0], + value=kv[:, 1], slots=slots, kv_scales=self.kv_scales, ) @@ -225,7 +218,6 @@ class FlashRWAttention(torch.nn.Module): kv_cache=kv_cache, kv_scales=self.kv_scales, seqlen=seqlen, - block_tables=block_tables, softmax_scale=self.softmax_scale, ) # Decode @@ -235,7 +227,6 @@ class FlashRWAttention(torch.nn.Module): kv_cache, self.kv_head_mapping, self.softmax_scale, - block_tables, seqlen, kv_scales=self.kv_scales, hpu_attention_meta=hpu_attention_meta, @@ -309,10 +300,8 @@ class FlashRWLargeAttention(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ): qkv = self.query_key_value(hidden_states) @@ -329,14 +318,9 @@ class FlashRWLargeAttention(torch.nn.Module): # Inplace rotary self.rotary_emb(query, torch.select(kv, dim=2, index=0), cos, sin) - if prefill_cache_indices is not None: - kv_to_cache = kv[prefill_cache_indices] - else: - kv_to_cache = kv - kv_cache.store( - key=kv_to_cache[:, :, 0].contiguous(), - value=kv_to_cache[:, :, 1].contiguous(), + key=kv[:, :, 0].contiguous(), + value=kv[:, :, 1].contiguous(), slots=slots, kv_scales=self.kv_scales, ) @@ -351,7 +335,6 @@ class FlashRWLargeAttention(torch.nn.Module): kv_cache=kv_cache, kv_scales=self.kv_scales, seqlen=seqlen, - block_tables=block_tables, softmax_scale=self.softmax_scale, ) # Decode @@ -361,7 +344,6 @@ class FlashRWLargeAttention(torch.nn.Module): kv_cache, self.kv_head_mapping, self.softmax_scale, - block_tables, seqlen, kv_scales=self.kv_scales, hpu_attention_meta=hpu_attention_meta, @@ -447,10 +429,8 @@ class FlashRWLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, 
): if self.parallel_attn: @@ -462,10 +442,8 @@ class FlashRWLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ) @@ -485,10 +463,8 @@ class FlashRWLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ) @@ -573,10 +549,8 @@ class FlashRWLargeLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ): # Layer norm. @@ -589,10 +563,8 @@ class FlashRWLargeLayer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ) @@ -651,10 +623,8 @@ class FlashRWModel(FlashRWPreTrainedModel): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - prefill_cache_indices: Optional[torch.Tensor], hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: hidden_states = self.word_embeddings(input_ids) @@ -672,10 +642,8 @@ class FlashRWModel(FlashRWPreTrainedModel): sin, cu_seqlen_prefill, kv_cache[i], - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ) @@ -703,10 +671,8 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - prefill_cache_indices: Optional[torch.Tensor], hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, @@ -716,10 +682,8 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel): position_ids, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ) if lm_head_indices is not None: diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index a41518d7..b68f4784 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -265,10 +265,8 @@ class FlashMQAttention(torch.nn.Module): hidden_states, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ): qkv = self.c_attn(hidden_states) @@ -282,14 +280,9 @@ class FlashMQAttention(torch.nn.Module): query = query.view(-1, self.num_heads, self.head_size) key_value = key_value.view(-1, 2, 1, self.head_size) - if prefill_cache_indices is not None: - key_value_to_cache = key_value[prefill_cache_indices] - else: - key_value_to_cache = key_value - kv_cache.store( - key=key_value_to_cache[:, 0], - value=key_value_to_cache[:, 1], + key=key_value[:, 0], + value=key_value[:, 1], slots=slots, kv_scales=self.kv_scales, ) @@ -304,7 +297,6 @@ class FlashMQAttention(torch.nn.Module): kv_cache=kv_cache, kv_scales=self.kv_scales, seqlen=seqlen, - block_tables=block_tables, softmax_scale=self.softmax_scale, ) # Decode @@ -314,7 +306,6 @@ class FlashMQAttention(torch.nn.Module): kv_cache, self.kv_head_mapping, self.softmax_scale, - block_tables, seqlen, kv_scales=self.kv_scales, hpu_attention_meta=hpu_attention_meta, @@ -379,10 
+370,8 @@ class Block(nn.Module): residual, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ): hidden_states, residual = self.ln_1(hidden_states, residual) @@ -390,10 +379,8 @@ class Block(nn.Module): hidden_states, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ) @@ -445,10 +432,8 @@ class FlashSantacoderModel(nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - prefill_cache_indices: Optional[torch.Tensor], hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: hidden_states = self.wte(input_ids) + self.wpe(position_ids) @@ -463,10 +448,8 @@ class FlashSantacoderModel(nn.Module): residual, cu_seqlen_prefill, kv_cache[i], - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ) @@ -496,11 +479,9 @@ class FlashSantacoderForCausalLM(nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, hpu_attention_meta: Optional[HPUPagedAttentionMetadata], - prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, ) -> torch.Tensor: @@ -509,10 +490,8 @@ class FlashSantacoderForCausalLM(nn.Module): position_ids, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, hpu_attention_meta, ) if lm_head_indices is not None: diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py index 082e5d82..76f6f473 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py @@ -235,10 +235,8 @@ class Starcoder2Attention(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, adapter_data, hpu_attention_meta, ): @@ -255,14 +253,9 @@ class Starcoder2Attention(torch.nn.Module): self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) - if prefill_cache_indices is not None: - kv_to_cache = kv[prefill_cache_indices] - else: - kv_to_cache = kv - kv_cache.store( - key=kv_to_cache[:, 0], - value=kv_to_cache[:, 1], + key=kv[:, 0], + value=kv[:, 1], slots=slots, kv_scales=self.kv_scales, ) @@ -277,7 +270,6 @@ class Starcoder2Attention(torch.nn.Module): kv_cache=kv_cache, kv_scales=self.kv_scales, seqlen=seqlen, - block_tables=block_tables, softmax_scale=self.softmax_scale, window_size_left=self.max_past, ) @@ -288,7 +280,6 @@ class Starcoder2Attention(torch.nn.Module): kv_cache, self.kv_head_mapping, self.softmax_scale, - block_tables, seqlen, kv_scales=self.kv_scales, hpu_attention_meta=hpu_attention_meta, @@ -448,10 +439,8 @@ class Starcoder2Layer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, adapter_data, hpu_attention_meta, ): @@ -464,10 +453,8 @@ class Starcoder2Layer(nn.Module): sin, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, adapter_data, hpu_attention_meta, ) @@ -518,10 +505,8 @@ class 
Starcoder2Model(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - prefill_cache_indices: Optional[torch.Tensor], adapter_data, hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: @@ -540,10 +525,8 @@ class Starcoder2Model(torch.nn.Module): sin, cu_seqlen_prefill, kv_cache[i], - block_tables, slots, seqlen, - prefill_cache_indices, adapter_data, hpu_attention_meta, ) @@ -589,29 +572,20 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, hpu_attention_meta: Optional[HPUPagedAttentionMetadata], - prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, ) -> torch.Tensor: - if prefill_cache_indices is not None and slots.size( - 0 - ) != prefill_cache_indices.size(0): - # Slots also need to be sliced as it has the same size as the whole kv tensor - slots = slots[prefill_cache_indices] hidden_states = self.model( input_ids, position_ids, cu_seqlen_prefill, kv_cache, - block_tables, slots, seqlen, - prefill_cache_indices, adapter_data, hpu_attention_meta, ) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics2.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics2.py index 0a4305ec..02806ac9 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics2.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics2.py @@ -740,11 +740,9 @@ class Idefics2ForConditionalGeneration(nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, hpu_attention_meta: Optional[HPUPagedAttentionMetadata], - prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor] = None, pixel_values: torch.FloatTensor = None, pixel_attention_mask: Optional[torch.BoolTensor] = None, @@ -843,11 +841,9 @@ class Idefics2ForConditionalGeneration(nn.Module): position_ids=position_ids, cu_seqlen_prefill=cu_seqlen_prefill, kv_cache=kv_cache, - block_tables=block_tables, slots=slots, seqlen=seqlen, hpu_attention_meta=hpu_attention_meta, - prefill_cache_indices=None, adapter_data=adapter_data, ) if lm_head_indices is not None: diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics3.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics3.py index 9278a86a..964526fc 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics3.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/idefics3.py @@ -483,11 +483,9 @@ class Idefics3ForConditionalGeneration(nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, hpu_attention_meta: Optional[HPUPagedAttentionMetadata], - prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor] = None, pixel_values: torch.FloatTensor = None, pixel_attention_mask: Optional[torch.BoolTensor] = None, @@ 
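Note: the block removed from FlashStarcoder2ForCausalLM.forward (slicing slots by prefill_cache_indices when their sizes disagreed) is exactly the data-dependent shape the commit title refers to: the sliced length changes with the batch contents, so an HPU graph cannot be replayed. A tiny sketch of the removed pattern, with hypothetical values; the fixed-shape replacement now lives in flash_causal_lm.py and is shown further below:

import torch

def prefill_slots_old(slots: torch.Tensor, prefill_cache_indices: torch.Tensor) -> torch.Tensor:
    # Removed pattern: the output length tracks the number of non-padded
    # tokens, so every batch composition yields a different tensor shape.
    return slots[prefill_cache_indices]

print(prefill_slots_old(torch.arange(8), torch.tensor([2, 3, 6])).shape)     # torch.Size([3])
print(prefill_slots_old(torch.arange(8), torch.tensor([1, 2, 3, 7])).shape)  # torch.Size([4])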
-587,11 +585,9 @@ class Idefics3ForConditionalGeneration(nn.Module): position_ids=position_ids, cu_seqlen_prefill=cu_seqlen_prefill, kv_cache=kv_cache, - block_tables=block_tables, slots=slots, seqlen=seqlen, hpu_attention_meta=hpu_attention_meta, - prefill_cache_indices=None, adapter_data=adapter_data, ) if lm_head_indices is not None: diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py index 75dd2b40..441b0016 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py @@ -906,11 +906,9 @@ class Qwen2_5VLForConditionalGeneration(nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, hpu_attention_meta: Optional[HPUPagedAttentionMetadata], - prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor], pixel_values: torch.FloatTensor = None, image_grid_thw: Optional[torch.LongTensor] = None, @@ -938,11 +936,9 @@ class Qwen2_5VLForConditionalGeneration(nn.Module): position_ids=position_ids, cu_seqlen_prefill=cu_seqlen_prefill, kv_cache=kv_cache, - block_tables=block_tables, slots=slots, seqlen=seqlen, hpu_attention_meta=hpu_attention_meta, - prefill_cache_indices=prefill_cache_indices, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_vl.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_vl.py index 3b4965a2..47ae2ac9 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_vl.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_vl.py @@ -480,11 +480,9 @@ class Qwen2VLForConditionalGeneration(nn.Module): position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], - block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, hpu_attention_meta: Optional[HPUPagedAttentionMetadata], - prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor], pixel_values: torch.FloatTensor = None, image_grid_thw: Optional[torch.LongTensor] = None, @@ -511,11 +509,9 @@ class Qwen2VLForConditionalGeneration(nn.Module): position_ids=position_ids, cu_seqlen_prefill=cu_seqlen_prefill, kv_cache=kv_cache, - block_tables=block_tables, slots=slots, seqlen=seqlen, hpu_attention_meta=hpu_attention_meta, - prefill_cache_indices=prefill_cache_indices, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] diff --git a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py index e032242c..b0859c3d 100644 --- a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py @@ -25,7 +25,7 @@ from typing import ( Dict, Union, ) - +import torch.nn.functional as F from text_generation_server.adapters import AdapterBatchData, AdapterBatchMetadata from text_generation_server.utils.chunks import concat_text_chunks from text_generation_server.models import Model @@ -116,7 +116,6 @@ def prepare_for_decode( block_list = 
flatten(block_tables) block_groups = flatten(block_groups) block_usage = flatten(block_usage) - assert len(block_list) == len(block_groups) assert len(block_list) == len(block_usage) if use_contiguous_pa: @@ -979,29 +978,27 @@ class FlashCausalLMBatch(Batch): # padding to left to work with sliding window # use prefill_cache_indices to indicate the valid kv slot, update prefill_next_token_indices to indicate # the right logit position - input_ids_padded = None - input_ids_padded_length = None + input_ids_padded_length = [] + # need extra pad to match warmup seq + extra_pad = 0 if isinstance(self.input_ids, list) and len(self) > 1: - input_ids_padded = [] input_ids_padded_length = [] + input_ids = [] for input_id in self.input_ids: - padded = self.max_input_length - len(input_id) - input_id_padded = input_id + padded = self.max_input_length - len(input_id) + extra_pad if padded > 0: - input_id_padded = [0] * padded + input_id_padded - input_ids_padded.append(input_id_padded) + input_id = [0] * padded + input_id + input_ids.append(input_id) input_ids_padded_length.append(padded) - input_ids_padded = np.concatenate(input_ids_padded, dtype=np.int64) - input_ids_padded = torch.tensor( - input_ids_padded, dtype=torch.int64, device=device - ) - - if isinstance(self.input_ids, list): - if len(self) > 1: - input_ids = np.concatenate(self.input_ids, dtype=np.int64) - else: - input_ids = self.input_ids[0] + input_ids = np.concatenate(input_ids, dtype=np.int64) self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device) + elif isinstance(self.input_ids, list): + input_ids = self.input_ids[0] + input_ids_padded_length.append(extra_pad) + input_ids = [0] * extra_pad + input_ids + self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device) + else: + logger.error("should not be here, prefill self.input_ids is a tensor") self.input_lengths_tensor = torch.tensor( self.input_lengths, dtype=torch.int32, device=device @@ -1052,10 +1049,9 @@ class FlashCausalLMBatch(Batch): request_position_ids = torch.arange( cache_length, cache_length + input_length, dtype=torch.int32 ) - if input_ids_padded is not None: - position_ids.append( - torch.ones(input_ids_padded_length[i], dtype=torch.int32) - ) + request_position_ids = F.pad( + request_position_ids, (input_ids_padded_length[i], 0), value=1 + ) position_ids.append(request_position_ids) if not r.slots: @@ -1079,12 +1075,11 @@ class FlashCausalLMBatch(Batch): cumulative_slot_tokens += len(request_slots) # Create tensor to slice into the kv tensor in prefill - if input_ids_padded is not None: - # hpu need request_prefill_cache_indices to skip padding in kv cache - sliding_window = get_sliding_windows() - if sliding_window is None: - sliding_window = input_length - cumulative_length += input_ids_padded_length[i] + # hpu need request_prefill_cache_indices to skip padding in kv cache + sliding_window = get_sliding_windows() + if sliding_window is None: + sliding_window = input_length + cumulative_length += input_ids_padded_length[i] if sliding_window is not None: request_prefill_cache_indices = torch.arange( cumulative_length + max(0, input_length - sliding_window), @@ -1105,8 +1100,7 @@ class FlashCausalLMBatch(Batch): prefill_cu_outlens.append(prefill_out_cumulative_length + 1) prefill_out_cumulative_length += 1 - if sliding_window is not None: - prefill_cache_indices.append(request_prefill_cache_indices) + prefill_cache_indices.append(request_prefill_cache_indices) ADAPTER_TO_INDEX = get_adapter_to_index() if ADAPTER_TO_INDEX: @@ -1171,23 
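Note: in FlashCausalLMBatch, prefill inputs are now left-padded to max_input_length (plus an extra_pad that is currently 0, reserved to match warmup sequence lengths), and the per-request pad lengths are recorded so downstream index math can be shifted. A condensed sketch of that padding step, with a hypothetical helper name and the batch bookkeeping reduced to return values:

import numpy as np
import torch
import torch.nn.functional as F

def left_pad_prefill(requests, max_input_length: int, extra_pad: int = 0):
    # Returns the flat, left-padded input ids plus the pad length of each
    # request, mirroring input_ids / input_ids_padded_length in the batch.
    padded, pad_lens = [], []
    for ids in requests:
        pad = max_input_length - len(ids) + extra_pad
        padded.append([0] * pad + list(ids))
        pad_lens.append(pad)
    flat = np.concatenate(padded, dtype=np.int64)
    return torch.tensor(flat, dtype=torch.int64), pad_lens

def padded_position_ids(cache_length: int, input_length: int, pad: int) -> torch.Tensor:
    # Real tokens keep their true positions; the pad region is filled with 1,
    # as in F.pad(request_position_ids, (pad, 0), value=1) above.
    pos = torch.arange(cache_length, cache_length + input_length, dtype=torch.int32)
    return F.pad(pos, (pad, 0), value=1)

input_ids, pad_lens = left_pad_prefill([[5, 6, 7], [8]], max_input_length=4)
print(input_ids.tolist(), pad_lens)           # [0, 5, 6, 7, 0, 0, 0, 8] [1, 3]
print(padded_position_ids(0, 3, 1).tolist())  # [1, 0, 1, 2]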
+1165,20 @@ class FlashCausalLMBatch(Batch): position_ids = torch.cat(position_ids) if slot_indices: slot_indices = torch.cat(slot_indices) - if sliding_window is not None: - prefill_cache_indices = torch.cat(prefill_cache_indices) + prefill_cache_indices = torch.cat(prefill_cache_indices) else: if position_ids: position_ids = position_ids[0] if slot_indices: slot_indices = slot_indices[0] - if sliding_window is not None: - prefill_cache_indices = prefill_cache_indices[0] + prefill_cache_indices = prefill_cache_indices[0] self.position_ids = position_ids.to(device) self.slot_indices = slot_indices.to(device) self.prefill_cu_outlens = prefill_cu_outlens - self.prefill_cache_indices = ( - prefill_cache_indices.to(device) if sliding_window is not None else None - ) + self.prefill_cache_indices = torch.zeros_like(self.input_ids, dtype=torch.bool) + self.prefill_cache_indices[prefill_cache_indices.to(device)] = True if all_prefill_logprobs: prefill_head_indices = None @@ -1203,21 +1194,19 @@ class FlashCausalLMBatch(Batch): self.prefill_head_indices = prefill_head_indices self.prefill_next_token_indices = prefill_next_token_indices - if input_ids_padded is not None: - self.input_ids = input_ids_padded - input_ids_padded_length_tensor = torch.cumsum( - torch.tensor(input_ids_padded_length, dtype=torch.int64, device=device), - dim=-1, + input_ids_padded_length_tensor = torch.cumsum( + torch.tensor(input_ids_padded_length, dtype=torch.int64, device=device), + dim=-1, + ) + if self.prefill_head_indices is not None: + self.prefill_head_indices = ( + self.prefill_head_indices + input_ids_padded_length_tensor ) - if self.prefill_head_indices is not None: - self.prefill_head_indices = ( - self.prefill_head_indices + input_ids_padded_length_tensor - ) - if self.prefill_next_token_indices is not None: - self.prefill_next_token_indices = ( - self.prefill_next_token_indices + input_ids_padded_length_tensor - ) + if self.prefill_next_token_indices is not None: + self.prefill_next_token_indices = ( + self.prefill_next_token_indices + input_ids_padded_length_tensor + ) if adapter_set: adapter_indices = torch.cat(adapter_indices_list).to( @@ -1232,7 +1221,6 @@ class FlashCausalLMBatch(Batch): adapter_segments = torch.tensor( adapter_segments, dtype=torch.int32, device=device ) - self.adapter_meta = AdapterBatchMetadata( adapter_indices=adapter_indices, adapter_set=adapter_set, @@ -1490,14 +1478,6 @@ class FlashCausalLM(Model): self.kv_cache_dtype, self.device, ) - for bs in [1, 2, 4, 8]: - for seqlen in [32, 64, 128, 256, 512, 1024]: - self.warmup_prefill(seqlen, bs) - - for bs in [1, 2, 4, 8]: - for block_num in [1, 2, 4, 8, 16]: - self.warmup_decode(bs, block_num * bs) - return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens def warmup_prefill(self, prompt_len: int, bs: int): @@ -1539,10 +1519,8 @@ class FlashCausalLM(Model): position_ids=position_ids, cu_seqlen_prefill=cu_seqlen_prefill, kv_cache=self.kv_cache, - block_tables=None, # block_table is not used in hpu pageattn. 
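Note: prefill_cache_indices itself is no longer passed to the model. The batch keeps it as a boolean mask with the same length as input_ids (True on real tokens, via torch.zeros_like(..., dtype=torch.bool)), and prefill_head_indices / prefill_next_token_indices are shifted right by the cumulative padding inserted before each request. A sketch of both, assuming one logit index per request:

import torch

def build_prefill_cache_mask(padded_len: int, valid_positions: torch.Tensor) -> torch.Tensor:
    # Same length as the padded input_ids, so its shape never changes;
    # True marks tokens whose KV actually belongs in the cache.
    mask = torch.zeros(padded_len, dtype=torch.bool)
    mask[valid_positions] = True
    return mask

def shift_by_padding(indices: torch.Tensor, pad_lens) -> torch.Tensor:
    # Indices that pointed into the unpadded, concatenated batch move right
    # by the total padding inserted before their request, i.e. the cumulative
    # sum of per-request pad lengths.
    offset = torch.cumsum(torch.tensor(pad_lens, dtype=torch.int64), dim=-1)
    return indices + offset

mask = build_prefill_cache_mask(8, torch.tensor([1, 2, 3, 7]))
print(mask.tolist())
print(shift_by_padding(torch.tensor([2, 3]), [1, 3]).tolist())  # [3, 7]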
remove it to avoid shape change in hpu graph slots=slots, seqlen=trim_seqlen_metadata(seqlen), - prefill_cache_indices=None, lm_head_indices=lm_head_indices, adapter_data=None, hpu_attention_meta=None, @@ -1562,7 +1540,6 @@ class FlashCausalLM(Model): for i in range(bs): slots.append(BLOCK_SIZE * block_tables[i][-1] + BLOCK_SIZE - 1) slots = torch.tensor(slots, dtype=torch.int64, device=self.device) - input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device) cache_lengths_tensor = ( torch.ones(bs, dtype=torch.int32, device=self.device) * past_len @@ -1575,11 +1552,11 @@ class FlashCausalLM(Model): cache_lengths=cache_lengths_tensor, cu_seqlen_q=cu_seqlen_prefill, ) - block_num = cache_lengths_tensor // BLOCK_SIZE + 1 block_tables_valid = [] for i, bt in enumerate(block_tables.tolist()): block_tables_valid.append(bt[0 : block_num[i]]) + hpu_attention_meta = prepare_for_decode( self.dtype, self.use_contiguous_pa, @@ -1595,10 +1572,8 @@ class FlashCausalLM(Model): position_ids=position_ids, cu_seqlen_prefill=None, kv_cache=self.kv_cache, - block_tables=None, # block_table is not used in hpu pageattn. remove it to avoid shape change in hpu graph slots=slots, seqlen=trim_seqlen_metadata(seqlen), - prefill_cache_indices=None, lm_head_indices=None, adapter_data=None, hpu_attention_meta=hpu_attention_meta, @@ -1684,26 +1659,23 @@ class FlashCausalLM(Model): kwargs = {} if htorch.utils.internal.is_lazy(): kwargs["bypass_hpu_graphs"] = False + if batch.prefill_cache_indices is not None: + slots_pad = torch.zeros_like(input_ids) + slots_pad[batch.prefill_cache_indices] = slots + slots = slots_pad logits, speculative_logits = self.model.forward( input_ids=input_ids, position_ids=position_ids, cu_seqlen_prefill=cu_seqlen_prefill, kv_cache=kv_cache, - block_tables=None, slots=slots, seqlen=trim_seqlen_metadata(seqlen), - prefill_cache_indices=batch.prefill_cache_indices, lm_head_indices=lm_head_indices, # TODO not support adapter now, need the add in the future adapter_data=None, hpu_attention_meta=batch.hpu_attn_meta, **kwargs, ) - if batch.prefill_cache_indices is not None: - batch.prefill_cache_indices = None - # fix following runtime error in graph replay - # RuntimeError: Neither storage attached to input tensor, not its view - htorch.core.mark_step() return logits, speculative_logits @tracer.start_as_current_span("generate_token") @@ -1801,7 +1773,14 @@ class FlashCausalLM(Model): # instantly become of shape [BATCH_SIZE] if prefill and finished_prefilling: indices = batch.cu_seqlen_prefill[1:] - 1 - batch.position_ids = batch.position_ids[indices] + # pad in left + if batch.prefill_cache_indices is not None: + batch.position_ids = batch.position_ids[batch.prefill_cache_indices][ + indices + ] + else: + batch.position_ids = batch.position_ids[indices] + batch.slot_indices = batch.slot_indices[indices] batch.adapter_meta.adapter_indices = batch.adapter_meta.adapter_indices[ indices diff --git a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py index 48bfce89..b5d93cbc 100644 --- a/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py @@ -462,11 +462,9 @@ class FlashVlmCausalLM(FlashCausalLM): position_ids=position_ids, cu_seqlen_prefill=cu_seqlen_prefill, kv_cache=kv_cache, - block_tables=None, # block_table is not used in hpu pageattn. 
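Note: two call-site changes in flash_causal_lm.py consume that mask. Before the forward pass, the per-token slots are scattered onto the padded token axis (pad positions fall back to slot 0), and after prefill the position ids of the next tokens are recovered by first dropping pad tokens with the mask and then taking cu_seqlen_prefill[1:] - 1. A minimal sketch of both, reusing the toy layout from the padding example:

import torch

def pad_slots(input_ids: torch.Tensor, slots: torch.Tensor,
              prefill_cache_mask: torch.Tensor) -> torch.Tensor:
    # slots_pad has input_ids' (static) shape; real tokens get their assigned
    # slot, pad tokens fall back to slot 0 as in the patch.
    slots_pad = torch.zeros_like(input_ids)
    slots_pad[prefill_cache_mask] = slots
    return slots_pad

def next_token_position_ids(position_ids: torch.Tensor,
                            prefill_cache_mask: torch.Tensor,
                            cu_seqlen_prefill: torch.Tensor) -> torch.Tensor:
    # Drop pad tokens, then keep the last token of every request.
    indices = cu_seqlen_prefill[1:] - 1
    return position_ids[prefill_cache_mask][indices]

input_ids = torch.tensor([0, 5, 6, 7, 0, 0, 0, 8])
mask = torch.tensor([False, True, True, True, False, False, False, True])
print(pad_slots(input_ids, torch.tensor([40, 41, 42, 80]), mask).tolist())
# [0, 40, 41, 42, 0, 0, 0, 80]
position_ids = torch.tensor([1, 0, 1, 2, 1, 1, 1, 0])
print(next_token_position_ids(position_ids, mask, torch.tensor([0, 3, 4])).tolist())  # [2, 0]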
remove it to avoid shape change in hpu graph slots=slots, seqlen=trim_seqlen_metadata(seqlen), hpu_attention_meta=batch.hpu_attn_meta, - prefill_cache_indices=batch.prefill_cache_indices, lm_head_indices=lm_head_indices, pixel_values=batch.pixel_values, pixel_attention_mask=batch.pixel_attention_mask, diff --git a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py index 4471aab3..eabbe247 100644 --- a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py @@ -288,11 +288,9 @@ class FlashMllamaCausalLM(FlashVlmCausalLM): position_ids=position_ids, cu_seqlen_prefill=cu_seqlen_prefill, kv_cache=kv_cache, - block_tables=None, # block_table is not used in hpu pageattn. remove it to avoid shape change in hpu graph slots=slots, seqlen=trim_seqlen_metadata(seqlen), hpu_attention_meta=batch.hpu_attn_meta, - prefill_cache_indices=batch.prefill_cache_indices, lm_head_indices=lm_head_indices, cross_attention_states=cross_attention_states, # TODO list
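Note: the vision-language call sites (FlashVlmCausalLM, FlashMllamaCausalLM) get the same treatment, so every model.forward now takes inputs whose shapes depend only on the padded batch geometry, not on which tokens are real. A toy check of that property, with all names invented:

import torch

def make_prefill_inputs(requests, max_input_length: int):
    # Whatever the real request lengths are, input_ids and the valid mask
    # always come out with batch_size * max_input_length elements.
    input_ids, mask = [], []
    for ids in requests:
        pad = max_input_length - len(ids)
        input_ids += [0] * pad + list(ids)
        mask += [False] * pad + [True] * len(ids)
    return torch.tensor(input_ids), torch.tensor(mask)

a_ids, a_mask = make_prefill_inputs([[1, 2, 3], [4]], max_input_length=4)
b_ids, b_mask = make_prefill_inputs([[5], [6, 7, 8, 9]], max_input_length=4)
# Identical shapes across batches are what let an HPU graph be captured once
# and replayed, which is the point of this commit.
assert a_ids.shape == b_ids.shape == a_mask.shape == b_mask.shape
print(a_ids.shape)  # torch.Size([8])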