From a07e7437b6281fb104b15553a4696a2450a6a201 Mon Sep 17 00:00:00 2001 From: "Wang, Yi A" Date: Sun, 16 Mar 2025 22:37:34 -0700 Subject: [PATCH] enable all the models, not tested yet Signed-off-by: Wang, Yi A --- .../custom_modeling/flash_cohere_modeling.py | 38 ++++++++---- .../custom_modeling/flash_dbrx_modeling.py | 41 ++++++++----- .../flash_deepseek_v2_modeling.py | 38 +++++++----- .../flash_deepseek_v3_modeling.py | 37 ++++++++---- .../custom_modeling/flash_gemma2_modeling.py | 37 +++++++----- .../custom_modeling/flash_gemma_modeling.py | 36 +++++++---- .../custom_modeling/flash_gpt2_modeling.py | 34 +++++++---- .../custom_modeling/flash_gptj_modeling.py | 38 +++++++----- .../custom_modeling/flash_llama_modeling.py | 8 +-- .../custom_modeling/flash_mistral_modeling.py | 28 ++++----- .../custom_modeling/flash_mixtral_modeling.py | 28 ++++----- .../custom_modeling/flash_neox_modeling.py | 38 +++++++----- .../custom_modeling/flash_phi_modeling.py | 35 +++++++---- .../custom_modeling/flash_qwen2_modeling.py | 29 +++++---- .../custom_modeling/flash_rw_modeling.py | 59 ++++++++++++------- .../flash_santacoder_modeling.py | 32 ++++++---- .../flash_starcoder2_modeling.py | 28 ++++----- .../models/flash_causal_lm.py | 40 ++++++------- 18 files changed, 374 insertions(+), 250 deletions(-) diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py index 44df7964..8d32032d 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py @@ -29,6 +29,7 @@ from text_generation_server.layers.attention import ( paged_attention, attention, Seqlen, + HPUPagedAttentionMetadata, ) from text_generation_server.layers.attention.kv_cache import get_kv_scales from text_generation_server.layers import ( @@ -221,7 +222,8 @@ class FlashCohereAttention(torch.nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ): qkv = self.query_key_value(hidden_states) query, key, value = qkv.split( @@ -245,9 +247,16 @@ class FlashCohereAttention(torch.nn.Module): self.rotary_emb(query, key, cos, sin) + if prefill_cache_indices is not None: + key_to_cache = key[prefill_cache_indices] + value_to_cache = value[prefill_cache_indices] + else: + key_to_cache = key + value_to_cache = value + kv_cache.store( - key=key, - value=value, + key=key_to_cache, + value=value_to_cache, slots=slots, kv_scales=self.kv_scales, ) @@ -274,8 +283,8 @@ class FlashCohereAttention(torch.nn.Module): self.softmax_scale, block_tables, seqlen, - max_s, kv_scales=self.kv_scales, + hpu_attention_meta=hpu_attention_meta, ) return self.o_proj( @@ -350,7 +359,8 @@ class FlashCohereLayer(nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ): normed_hidden_states, res = self.input_layernorm(hidden_states, residual) @@ -364,7 +374,8 @@ class FlashCohereLayer(nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ) mlp_output = self.mlp(normed_hidden_states) @@ -416,15 +427,14 @@ class FlashCohereModel(torch.nn.Module): block_tables: torch.Tensor, slots: torch.Tensor, seqlen: torch.Tensor, - max_s: int, + prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: hidden_states = 
self.embed_tokens(input_ids) # Get rotary cos and sin for this forward # Avoid to index in each layer - cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( - position_ids, max_s, hidden_states.dtype - ) + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) residual = None @@ -439,7 +449,8 @@ class FlashCohereModel(torch.nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ) hidden_states, _ = self.norm(hidden_states, residual) @@ -480,8 +491,8 @@ class FlashCohereForCausalLM(torch.nn.Module): block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: @@ -493,7 +504,8 @@ class FlashCohereForCausalLM(torch.nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py index ba86f579..c01bd1bc 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py @@ -27,6 +27,7 @@ from text_generation_server.layers.attention import ( paged_attention, attention, Seqlen, + HPUPagedAttentionMetadata, ) from text_generation_server.layers import ( FastLinear, @@ -312,7 +313,8 @@ class DbrxAttention(torch.nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ): qkv = self.query_key_value(hidden_states) if self.clip_qkv is not None: @@ -329,10 +331,14 @@ class DbrxAttention(torch.nn.Module): kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size) self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) + if prefill_cache_indices is not None: + kv_to_cache = kv[prefill_cache_indices] + else: + kv_to_cache = kv kv_cache.store( - key=kv[:, 0], - value=kv[:, 1], + key=kv_to_cache[:, 0], + value=kv_to_cache[:, 1], slots=slots, kv_scales=self.kv_scales, ) @@ -359,8 +365,8 @@ class DbrxAttention(torch.nn.Module): self.softmax_scale, block_tables, seqlen, - max_s, kv_scales=self.kv_scales, + hpu_attention_meta=hpu_attention_meta, ) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) @@ -397,7 +403,8 @@ class DbrxNormAttentionNorm(nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ): normed_hidden_states, res = self.norm_1(hidden_states, residual) @@ -411,7 +418,8 @@ class DbrxNormAttentionNorm(nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ) # faster post attention rms norm @@ -631,7 +639,8 @@ class DbrxLayer(nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ): # Self Attention attn_output, attn_res = self.attn( @@ -644,7 +653,8 @@ class DbrxLayer(nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ) moe_output = self.moe(attn_output) @@ -688,15 +698,14 @@ class DbrxModel(torch.nn.Module): block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, + 
prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) # Get rotary cos and sin for this forward # Avoid to index in each layer - cos, sin = self.layers[0].attn.self_attn.rotary_emb.get_cos_sin( - position_ids, max_s, hidden_states.dtype - ) + cos, sin = self.layers[0].attn.self_attn.rotary_emb.get_cos_sin(position_ids) residual = None for i, layer in enumerate(self.layers): @@ -710,7 +719,8 @@ class DbrxModel(torch.nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ) hidden_states, _ = self.norm(hidden_states, residual) @@ -743,8 +753,8 @@ class FlashDbrxForCausalLM(torch.nn.Module): block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: @@ -756,7 +766,8 @@ class FlashDbrxForCausalLM(torch.nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py index e30510b4..3298a30a 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py @@ -33,6 +33,7 @@ from text_generation_server.layers.attention import ( Seqlen, attention, paged_attention, + HPUPagedAttentionMetadata, ) from text_generation_server.layers.attention.kv_cache import KVCache, get_kv_scales from text_generation_server.layers.layernorm import FastRMSNorm @@ -258,7 +259,8 @@ class DeepseekV2Attention(torch.nn.Module): block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, + prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ): if self.q_lora_rank is None: query = self.q_proj(hidden_states) @@ -314,10 +316,15 @@ class DeepseekV2Attention(torch.nn.Module): value = torch.nn.functional.pad( value, (0, self.head_pad_size - self.value_head_size), value=0 ) - + if prefill_cache_indices is not None: + key_to_cache = key[prefill_cache_indices] + value_to_cache = value[prefill_cache_indices] + else: + key_to_cache = key + value_to_cache = value kv_cache.store( - key=key, - value=value, + key=key_to_cache, + value=value_to_cache, slots=slots, kv_scales=self.kv_scales, ) @@ -344,8 +351,8 @@ class DeepseekV2Attention(torch.nn.Module): self.softmax_scale, block_tables, seqlen, - max_s, kv_scales=self.kv_scales, + hpu_attention_meta=hpu_attention_meta, ) # Remove padding. 
@@ -508,7 +515,8 @@ class DeepseekV2Layer(nn.Module): block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, + prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ): normed_hidden_states, residual = self.input_layernorm(hidden_states, residual) @@ -522,7 +530,8 @@ class DeepseekV2Layer(nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ) # faster post attention rms norm @@ -571,15 +580,14 @@ class DeepseekV2Model(torch.nn.Module): block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, + prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) # Get rotary cos and sin for this forward # Avoid to index in each layer - cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( - position_ids, max_s, hidden_states.dtype - ) + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) residual = None for i, layer in enumerate(self.layers): @@ -593,7 +601,8 @@ class DeepseekV2Model(torch.nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ) hidden_states, _ = self.norm(hidden_states, residual) @@ -623,8 +632,8 @@ class FlashDeepseekV2ForCausalLM(torch.nn.Module): block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: @@ -636,7 +645,8 @@ class FlashDeepseekV2ForCausalLM(torch.nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v3_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v3_modeling.py index 452fe3f2..736e0c9a 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v3_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v3_modeling.py @@ -33,6 +33,7 @@ from text_generation_server.layers.attention import ( Seqlen, attention, paged_attention, + HPUPagedAttentionMetadata, ) from text_generation_server.layers.attention.kv_cache import KVCache, get_kv_scales from text_generation_server.layers.layernorm import FastRMSNorm @@ -258,7 +259,8 @@ class DeepseekV3Attention(torch.nn.Module): block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, + prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ): if self.q_lora_rank is None: query = self.q_proj(hidden_states) @@ -315,9 +317,15 @@ class DeepseekV3Attention(torch.nn.Module): value, (0, self.head_pad_size - self.value_head_size), value=0 ) + if prefill_cache_indices is not None: + key_to_cache = key[prefill_cache_indices] + value_to_cache = value[prefill_cache_indices] + else: + key_to_cache = key + value_to_cache = value kv_cache.store( - key=key, - value=value, + key=key_to_cache, + value=value_to_cache, slots=slots, kv_scales=self.kv_scales, ) @@ -344,8 +352,8 @@ class DeepseekV3Attention(torch.nn.Module): self.softmax_scale, 
block_tables, seqlen, - max_s, kv_scales=self.kv_scales, + hpu_attention_meta=hpu_attention_meta, ) # Remove padding. @@ -517,7 +525,8 @@ class DeepseekV3Layer(nn.Module): block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, + prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ): normed_hidden_states, residual = self.input_layernorm(hidden_states, residual) @@ -531,7 +540,8 @@ class DeepseekV3Layer(nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ) # faster post attention rms norm @@ -580,15 +590,14 @@ class DeepseekV3Model(torch.nn.Module): block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, + prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) # Get rotary cos and sin for this forward # Avoid to index in each layer - cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( - position_ids, max_s, hidden_states.dtype - ) + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) residual = None for i, layer in enumerate(self.layers): @@ -602,7 +611,8 @@ class DeepseekV3Model(torch.nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ) hidden_states, _ = self.norm(hidden_states, residual) @@ -632,8 +642,8 @@ class FlashDeepseekV3ForCausalLM(torch.nn.Module): block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: @@ -645,7 +655,8 @@ class FlashDeepseekV3ForCausalLM(torch.nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py index ebf1b80e..5b7adad1 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py @@ -29,6 +29,7 @@ from text_generation_server.layers.attention import ( paged_attention, attention, Seqlen, + HPUPagedAttentionMetadata, ) from text_generation_server.layers import ( TensorParallelRowLinear, @@ -237,8 +238,9 @@ class FlashGemma2Attention(torch.nn.Module): block_tables, slots, seqlen, - max_s, adapter_data, + prefill_cache_indices, + hpu_attention_meta, ): qkv = self.query_key_value(hidden_states, adapter_data) query, kv = qkv.split( @@ -252,10 +254,14 @@ class FlashGemma2Attention(torch.nn.Module): kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size) self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) + if prefill_cache_indices is not None: + kv_to_cache = kv[prefill_cache_indices] + else: + kv_to_cache = kv kv_cache.store( - key=kv[:, 0], - value=kv[:, 1], + key=kv_to_cache[:, 0], + value=kv_to_cache[:, 1], slots=slots, kv_scales=self.kv_scales, ) @@ -284,9 +290,9 @@ class FlashGemma2Attention(torch.nn.Module): self.softmax_scale, block_tables, seqlen, - max_s, softcap=self.softcap, 
kv_scales=self.kv_scales, + hpu_attention_meta=hpu_attention_meta, ) return self.o_proj( @@ -399,8 +405,9 @@ class FlashGemma2Layer(nn.Module): block_tables, slots, seqlen, - max_s, adapter_data, + prefill_cache_indices, + hpu_attention_meta, ): normed_hidden_states, res = self.input_layernorm(hidden_states, residual) @@ -414,8 +421,9 @@ class FlashGemma2Layer(nn.Module): block_tables, slots, seqlen, - max_s, adapter_data, + prefill_cache_indices, + hpu_attention_meta, ) # faster post attention rms norm @@ -467,16 +475,15 @@ class FlashGemma2Model(torch.nn.Module): block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, - adapter_data: Optional[torch.Tensor] = None, + adapter_data: Optional[torch.Tensor], + prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: hidden_states = inputs_embeds # Get rotary cos and sin for this forward # Avoid to index in each layer - cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( - position_ids, max_s, hidden_states.dtype - ) + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) residual = None for i, layer in enumerate(self.layers): @@ -490,8 +497,9 @@ class FlashGemma2Model(torch.nn.Module): block_tables, slots, seqlen, - max_s, adapter_data, + prefill_cache_indices, + hpu_attention_meta, ) hidden_states, _ = self.norm(hidden_states, residual) @@ -538,8 +546,8 @@ class FlashGemma2ForCausalLM(torch.nn.Module): block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: @@ -552,8 +560,9 @@ class FlashGemma2ForCausalLM(torch.nn.Module): block_tables, slots, seqlen, - max_s, adapter_data, + prefill_cache_indices, + hpu_attention_meta, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py index ad3be80e..d26184b6 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py @@ -29,6 +29,7 @@ from text_generation_server.layers.attention import ( paged_attention, attention, Seqlen, + HPUPagedAttentionMetadata, ) from text_generation_server.layers import ( TensorParallelRowLinear, @@ -209,7 +210,8 @@ class FlashGemmaAttention(torch.nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ): qkv = self.query_key_value(hidden_states) query, kv = qkv.split( @@ -224,9 +226,14 @@ class FlashGemmaAttention(torch.nn.Module): self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) + if prefill_cache_indices is not None: + kv_to_cache = kv[prefill_cache_indices] + else: + kv_to_cache = kv + kv_cache.store( - key=kv[:, 0], - value=kv[:, 1], + key=kv_to_cache[:, 0], + value=kv_to_cache[:, 1], slots=slots, kv_scales=self.kv_scales, ) @@ -254,8 +261,8 @@ class FlashGemmaAttention(torch.nn.Module): self.softmax_scale, block_tables, seqlen, - max_s, kv_scales=self.kv_scales, + hpu_attention_meta=hpu_attention_meta, ) return self.o_proj(attn_output.view(-1, self.num_heads * 
self.head_size)) @@ -327,7 +334,8 @@ class FlashGemmaLayer(nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ): normed_hidden_states, res = self.input_layernorm(hidden_states, residual) @@ -341,7 +349,8 @@ class FlashGemmaLayer(nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ) # faster post attention rms norm @@ -389,15 +398,14 @@ class FlashGemmaModel(torch.nn.Module): block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, + prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: hidden_states = inputs_embeds # Get rotary cos and sin for this forward # Avoid to index in each layer - cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( - position_ids, max_s, hidden_states.dtype - ) + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) residual = None for i, layer in enumerate(self.layers): @@ -411,7 +419,8 @@ class FlashGemmaModel(torch.nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ) hidden_states, _ = self.norm(hidden_states, residual) @@ -456,8 +465,8 @@ class FlashGemmaForCausalLM(torch.nn.Module): block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: @@ -470,7 +479,8 @@ class FlashGemmaForCausalLM(torch.nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py index 906b34c1..a6e0a7de 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py @@ -28,6 +28,7 @@ from text_generation_server.layers.attention import ( paged_attention, attention, Seqlen, + HPUPagedAttentionMetadata, ) from text_generation_server.layers import ( TensorParallelRowLinear, @@ -215,7 +216,8 @@ class FlashGPT2Attention(torch.nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ): query, key, value = self.query_key_value(hidden_states).split( self.head_size * self.num_heads, dim=1 @@ -224,9 +226,16 @@ class FlashGPT2Attention(torch.nn.Module): key = key.view(-1, self.num_heads, self.head_size) value = value.view(-1, self.num_heads, self.head_size) + if prefill_cache_indices is not None: + key_to_cache = key[prefill_cache_indices] + value_to_cache = value[prefill_cache_indices] + else: + key_to_cache = key + value_to_cache = value + kv_cache.store( - key=key, - value=value, + key=key_to_cache, + value=value_to_cache, slots=slots, kv_scales=self.kv_scales, ) @@ -253,8 +262,8 @@ class FlashGPT2Attention(torch.nn.Module): self.softmax_scale, block_tables, seqlen, - max_s, kv_scales=self.kv_scales, + hpu_attention_meta=hpu_attention_meta, ) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) @@ -323,7 +332,8 @@ class FlashGPT2Layer(nn.Module): block_tables, slots, seqlen, - 
max_s, + prefill_cache_indices, + hpu_attention_meta, ): residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -336,7 +346,8 @@ class FlashGPT2Layer(nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ) hidden_states = attn_output + residual @@ -389,9 +400,8 @@ class FlashGPT2Model(torch.nn.Module): block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, - true_max_s: int, prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: hidden_states = inputs_embeds @@ -405,7 +415,8 @@ class FlashGPT2Model(torch.nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ) hidden_states = self.norm(hidden_states) @@ -442,7 +453,7 @@ class FlashGPT2ForCausalLM(torch.nn.Module): block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], prefill_cache_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, @@ -458,9 +469,8 @@ class FlashGPT2ForCausalLM(torch.nn.Module): block_tables, slots, seqlen, - max_s, - true_max_s=max_s, prefill_cache_indices=prefill_cache_indices, + hpu_attention_meta=hpu_attention_meta, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py index c23aa07f..9229a453 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py @@ -29,6 +29,7 @@ from text_generation_server.layers.attention import ( paged_attention, attention, Seqlen, + HPUPagedAttentionMetadata, ) from text_generation_server.layers import ( TensorParallelRowLinear, @@ -158,7 +159,8 @@ class FlashGPTJAttention(torch.nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ): query, key, value = self.query_key_value(hidden_states).split( self.head_size * self.num_heads, dim=1 @@ -175,9 +177,16 @@ class FlashGPTJAttention(torch.nn.Module): else: self.rotary_emb(query, key, cos, sin) + if prefill_cache_indices is not None: + key_to_cache = key[prefill_cache_indices] + value_to_cache = value[prefill_cache_indices] + else: + key_to_cache = key + value_to_cache = value + kv_cache.store( - key=key, - value=value, + key=key_to_cache, + value=value_to_cache, slots=slots, kv_scales=self.kv_scales, ) @@ -204,8 +213,8 @@ class FlashGPTJAttention(torch.nn.Module): self.softmax_scale, block_tables, seqlen, - max_s, kv_scales=self.kv_scales, + hpu_attention_meta=hpu_attention_meta, ) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) @@ -266,7 +275,8 @@ class FlashGPTJLayer(nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ): hidden_states, residual = self.input_layernorm(hidden_states, residual) # Self Attention @@ -279,7 +289,8 @@ class FlashGPTJLayer(nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ) feed_forward_hidden_states = self.mlp(hidden_states) @@ -326,16 +337,14 @@ class FlashGPTJModel(torch.nn.Module): block_tables: torch.Tensor, slots: torch.Tensor, 
seqlen: Seqlen, - max_s: int, prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: hidden_states = self.wte(input_ids) # Get rotary cos and sin for this forward # Avoid to index in each layer - cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( - position_ids, max_s, hidden_states.dtype - ) + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) residual = None for i, layer in enumerate(self.layers): @@ -349,7 +358,8 @@ class FlashGPTJModel(torch.nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ) hidden_states, _ = self.ln_f(hidden_states, residual) @@ -380,8 +390,8 @@ class FlashGPTJForCausalLM(torch.nn.Module): block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, - prefill_cache_indices: Optional[torch.Tensor] = None, + prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: @@ -393,8 +403,8 @@ class FlashGPTJForCausalLM(torch.nn.Module): block_tables, slots, seqlen, - max_s, prefill_cache_indices=prefill_cache_indices, + hpu_attention_meta=hpu_attention_meta, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index a118ace5..857e1757 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -206,7 +206,7 @@ class FlashLlamaAttention(torch.nn.Module): seqlen, adapter_data, prefill_cache_indices: Optional[torch.Tensor], - hpu_attention_meta: Optional[HPUPagedAttentionMetadata] = None, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ): qkv = self.query_key_value(hidden_states, adapter_data) query, kv = qkv.split( @@ -447,7 +447,7 @@ class FlashLlamaLayer(nn.Module): adapter_data, cross_attention_states, prefill_cache_indices: Optional[torch.Tensor], - hpu_attention_meta: Optional[HPUPagedAttentionMetadata] = None, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ): normed_hidden_states, res = self.input_layernorm(hidden_states, residual) @@ -559,8 +559,8 @@ class FlashLlamaModel(torch.nn.Module): seqlen: Seqlen, prefill_cache_indices: Optional[torch.Tensor], adapter_data, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], cross_attention_states=None, - hpu_attention_meta: Optional[HPUPagedAttentionMetadata] = None, ) -> torch.Tensor: hidden_states = inputs_embeds @@ -646,11 +646,11 @@ class FlashLlamaForCausalLM(torch.nn.Module): block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], prefill_cache_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, cross_attention_states=None, - hpu_attention_meta: Optional[HPUPagedAttentionMetadata] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: inputs_embeds = self.embed_tokens(input_ids) hidden_states = self.model( diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py 
b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py index a0116297..8214b6b7 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py @@ -31,6 +31,7 @@ from text_generation_server.layers.attention import ( paged_attention, attention, Seqlen, + HPUPagedAttentionMetadata, ) from text_generation_server.layers import ( TensorParallelRowLinear, @@ -180,9 +181,9 @@ class MistralAttention(torch.nn.Module): block_tables, slots, seqlen, - max_s, prefill_cache_indices, adapter_data, + hpu_attention_meta, ): qkv = self.query_key_value(hidden_states, adapter_data) query, kv = qkv.split( @@ -232,8 +233,8 @@ class MistralAttention(torch.nn.Module): self.softmax_scale, block_tables, seqlen, - max_s, kv_scales=self.kv_scales, + hpu_attention_meta=hpu_attention_meta, ) return self.o_proj( @@ -337,9 +338,9 @@ class MistralLayer(nn.Module): block_tables, slots, seqlen, - max_s, prefill_cache_indices, adapter_data, + hpu_attention_meta, ): normed_hidden_states, res = self.input_layernorm(hidden_states, residual) @@ -353,9 +354,9 @@ class MistralLayer(nn.Module): block_tables, slots, seqlen, - max_s, prefill_cache_indices, adapter_data, + hpu_attention_meta, ) # faster post attention rms norm @@ -405,17 +406,14 @@ class MistralModel(torch.nn.Module): block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, - true_max_s: int, prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], adapter_data: Optional[torch.Tensor] = None, ): hidden_states = inputs_embeds # Get rotary cos and sin for this forward # Avoid to index in each layer - cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( - position_ids, true_max_s, hidden_states.dtype - ) + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) residual = None for i, layer in enumerate(self.layers): @@ -429,9 +427,9 @@ class MistralModel(torch.nn.Module): block_tables, slots, seqlen, - max_s, prefill_cache_indices, adapter_data, + hpu_attention_meta, ) hidden_states, _ = self.norm(hidden_states, residual) @@ -480,13 +478,14 @@ class FlashMistralForCausalLM(torch.nn.Module): block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, ) -> torch.Tensor: - true_max_s = max_s - if prefill_cache_indices is not None: + if prefill_cache_indices is not None and slots.size( + 0 + ) != prefill_cache_indices.size(0): # Slots also need to be sliced as it has the same size as the whole kv tensor slots = slots[prefill_cache_indices] elif self.max_past is not None: @@ -503,9 +502,8 @@ class FlashMistralForCausalLM(torch.nn.Module): block_tables, slots, seqlen, - max_s, - true_max_s, prefill_cache_indices, + hpu_attention_meta, adapter_data, ) if lm_head_indices is not None: diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py index a45dd1e6..18ffe060 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py +++ 
b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py @@ -37,6 +37,7 @@ from text_generation_server.layers.attention import ( Seqlen, attention, paged_attention, + HPUPagedAttentionMetadata, ) from text_generation_server.layers.attention.kv_cache import get_kv_scales from text_generation_server.layers.layernorm import FastRMSNorm @@ -237,8 +238,8 @@ class MixtralAttention(torch.nn.Module): block_tables, slots, seqlen, - max_s, prefill_cache_indices, + hpu_attention_meta, ): qkv = self.query_key_value(hidden_states) query, kv = qkv.split( @@ -288,8 +289,8 @@ class MixtralAttention(torch.nn.Module): self.softmax_scale, block_tables, seqlen, - max_s, kv_scales=self.kv_scales, + hpu_attention_meta=hpu_attention_meta, ) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) @@ -386,8 +387,8 @@ class MixtralLayer(nn.Module): block_tables, slots, seqlen, - max_s, prefill_cache_indices, + hpu_attention_meta, ): normed_hidden_states, res = self.input_layernorm(hidden_states, residual) @@ -401,8 +402,8 @@ class MixtralLayer(nn.Module): block_tables, slots, seqlen, - max_s, prefill_cache_indices, + hpu_attention_meta, ) # faster post attention rms norm @@ -456,17 +457,14 @@ class MixtralModel(torch.nn.Module): block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, - true_max_s: int, prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) # Get rotary cos and sin for this forward # Avoid to index in each layer - cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( - position_ids, true_max_s, hidden_states.dtype - ) + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) residual = None for i, layer in enumerate(self.layers): @@ -480,8 +478,8 @@ class MixtralModel(torch.nn.Module): block_tables, slots, seqlen, - max_s, prefill_cache_indices, + hpu_attention_meta, ) hidden_states, _ = self.norm(hidden_states, residual) @@ -515,13 +513,14 @@ class FlashMixtralForCausalLM(torch.nn.Module): block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, ) -> torch.Tensor: - true_max_s = max_s - if prefill_cache_indices is not None: + if prefill_cache_indices is not None and slots.size( + 0 + ) != prefill_cache_indices.size(0): # Slots also need to be sliced as it has the same size as the whole kv tensor slots = slots[prefill_cache_indices] elif self.max_past is not None: @@ -537,9 +536,8 @@ class FlashMixtralForCausalLM(torch.nn.Module): block_tables, slots, seqlen, - max_s, - true_max_s, prefill_cache_indices, + hpu_attention_meta, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index 2301b63c..76269f22 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -30,6 +30,7 @@ from text_generation_server.layers.attention import ( paged_attention, attention, Seqlen, + HPUPagedAttentionMetadata, ) from 
text_generation_server.layers import ( TensorParallelRowLinear, @@ -149,7 +150,8 @@ class FlashNeoxAttention(torch.nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ): qkv = self.query_key_value(hidden_states) qkv = qkv.view(-1, 3, self.num_heads, self.head_size) @@ -164,10 +166,14 @@ class FlashNeoxAttention(torch.nn.Module): self.rotary_emb(query_rot, key_rot, cos, sin) qkv[:, 0] = torch.cat((query_rot, query_pass), dim=-1) qkv[:, 1] = torch.cat((key_rot, key_pass), dim=-1) + if prefill_cache_indices is not None: + qkv_to_cache = qkv[prefill_cache_indices] + else: + qkv_to_cache = qkv kv_cache.store( - key=qkv[:, 1], - value=qkv[:, 2], + key=qkv_to_cache[:, 1], + value=qkv_to_cache[:, 2], slots=slots, kv_scales=self.kv_scales, ) @@ -194,8 +200,8 @@ class FlashNeoxAttention(torch.nn.Module): self.softmax_scale, block_tables, seqlen, - max_s, kv_scales=self.kv_scales, + hpu_attention_meta=hpu_attention_meta, ) return self.dense(attn_output.view(-1, self.num_heads * self.head_size)) @@ -265,7 +271,8 @@ class FlashNeoXLayer(nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ): if self.use_parallel_residual: ln1_hidden_states, _ = self.input_layernorm(hidden_states) @@ -279,7 +286,8 @@ class FlashNeoXLayer(nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ) ln2_hidden_states, _ = self.post_attention_layernorm(hidden_states) @@ -303,7 +311,8 @@ class FlashNeoXLayer(nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ) hidden_states, residual = self.post_attention_layernorm( @@ -357,15 +366,14 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel): block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, + prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: hidden_states = self.embed_in(input_ids) # Get rotary cos and sin for this forward # Avoid to index in each layer - cos, sin = self.layers[0].attention.rotary_emb.get_cos_sin( - position_ids, max_s, hidden_states.dtype - ) + cos, sin = self.layers[0].attention.rotary_emb.get_cos_sin(position_ids) residual = None for i, layer in enumerate(self.layers): @@ -379,7 +387,8 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ) hidden_states, _ = self.final_layer_norm(hidden_states, residual) @@ -411,7 +420,7 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel): block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, @@ -424,7 +433,8 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py index 7382a7cb..21c4bc71 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py +++ 
b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py @@ -10,6 +10,7 @@ from text_generation_server.layers.attention import ( paged_attention, attention, Seqlen, + HPUPagedAttentionMetadata, ) from text_generation_server.layers import ( TensorParallelRowLinear, @@ -162,7 +163,8 @@ class FlashPhiAttention(torch.nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ): # Compute query, key, value and split qkv = self.query_key_value(hidden_states) @@ -188,9 +190,13 @@ class FlashPhiAttention(torch.nn.Module): ) # Reshape key and value and cache + if prefill_cache_indices is not None: + kv_to_cache = kv[prefill_cache_indices] + else: + kv_to_cache = kv kv_cache.store( - key=kv[:, 0], - value=kv[:, 1], + key=kv_to_cache[:, 0], + value=kv_to_cache[:, 1], slots=slots, kv_scales=self.kv_scales, ) @@ -216,8 +222,8 @@ class FlashPhiAttention(torch.nn.Module): self.softmax_scale, block_tables, seqlen, - max_s, kv_scales=self.kv_scales, + hpu_attention_meta=hpu_attention_meta, ) return self.dense(attn_output.view(-1, self.num_heads * self.head_size)) @@ -284,7 +290,8 @@ class FlashPhiLayer(nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ): hidden_states, res = self.input_layernorm(hidden_states, residual) # Self Attention @@ -297,7 +304,8 @@ class FlashPhiLayer(nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ) hidden_states = self.resid_dropout(attn_output).add( @@ -349,15 +357,14 @@ class FlashPhiModel(torch.nn.Module): block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, + prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) # Get rotary cos and sin for this forward # Avoid to index in each layer - cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( - position_ids, max_s, hidden_states.dtype - ) + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) residual = None for i, layer in enumerate(self.layers): @@ -371,7 +378,8 @@ class FlashPhiModel(torch.nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ) hidden_states, _ = self.norm(hidden_states, residual) @@ -404,8 +412,8 @@ class FlashPhiForCausalLM(torch.nn.Module): block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, ) -> torch.Tensor: @@ -417,7 +425,8 @@ class FlashPhiForCausalLM(torch.nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py index d6569a1d..c62435fe 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py @@ -9,6 +9,7 @@ from text_generation_server.layers.attention import ( paged_attention, attention, Seqlen, + HPUPagedAttentionMetadata, ) from 
text_generation_server.layers import ( TensorParallelRowLinear, @@ -108,8 +109,8 @@ class Qwen2Attention(torch.nn.Module): block_tables, slots, seqlen, - max_s, prefill_cache_indices, + hpu_attention_meta, ): qkv = self.query_key_value(hidden_states) query, kv = qkv.split( @@ -159,8 +160,8 @@ class Qwen2Attention(torch.nn.Module): self.softmax_scale, block_tables, seqlen, - max_s, kv_scales=self.kv_scales, + hpu_attention_meta=hpu_attention_meta, ) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) @@ -232,8 +233,8 @@ class Qwen2Layer(nn.Module): block_tables, slots, seqlen, - max_s, prefill_cache_indices, + hpu_attention_meta, ): normed_hidden_states, residual = self.input_layernorm(hidden_states) @@ -247,8 +248,8 @@ class Qwen2Layer(nn.Module): block_tables, slots, seqlen, - max_s, prefill_cache_indices, + hpu_attention_meta, ) hidden_states = attn_output + residual @@ -298,16 +299,13 @@ class Qwen2Model(torch.nn.Module): block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, - true_max_s: int, prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: hidden_states = inputs_embeds cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( position_ids, - true_max_s, - hidden_states.dtype, ) residual = None @@ -322,8 +320,8 @@ class Qwen2Model(torch.nn.Module): block_tables, slots, seqlen, - max_s, prefill_cache_indices, + hpu_attention_meta, ) hidden_states, _ = self.norm(hidden_states) @@ -369,13 +367,15 @@ class Qwen2ForCausalLM(torch.nn.Module): block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, - prefill_cache_indices: Optional[torch.Tensor] = None, + prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, ) -> torch.Tensor: - true_max_s = max_s - if prefill_cache_indices is not None: + + if prefill_cache_indices is not None and prefill_cache_indices.size( + 0 + ) != slots.size(0): # Slots also need to be sliced as it has the same size as the whole kv tensor slots = slots[prefill_cache_indices] elif self.max_past is not None: @@ -393,9 +393,8 @@ class Qwen2ForCausalLM(torch.nn.Module): block_tables, slots, seqlen, - max_s, - true_max_s, prefill_cache_indices, + hpu_attention_meta, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index fbf1a597..c6034bf0 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -19,6 +19,7 @@ from text_generation_server.layers.attention import ( attention, paged_attention, Seqlen, + HPUPagedAttentionMetadata, ) @@ -184,7 +185,8 @@ class FlashRWAttention(torch.nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ): qkv = self.query_key_value(hidden_states) @@ -201,9 +203,14 @@ class FlashRWAttention(torch.nn.Module): # Inplace rotary self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) + if prefill_cache_indices is not None: + kv_to_cache = kv[prefill_cache_indices] + else: + kv_to_cache = kv + kv_cache.store( - key=kv[:, 0], - value=kv[:, 1], + 
key=kv_to_cache[:, 0], + value=kv_to_cache[:, 1], slots=slots, kv_scales=self.kv_scales, ) @@ -230,8 +237,8 @@ class FlashRWAttention(torch.nn.Module): self.softmax_scale, block_tables, seqlen, - max_s, kv_scales=self.kv_scales, + hpu_attention_meta=hpu_attention_meta, ) return self.dense(attn_output.view(-1, self.num_heads * self.head_size)) @@ -305,7 +312,8 @@ class FlashRWLargeAttention(torch.nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ): qkv = self.query_key_value(hidden_states) qkv = qkv.view(-1, self.num_groups, self.num_heads + 2, self.head_size) @@ -321,9 +329,14 @@ class FlashRWLargeAttention(torch.nn.Module): # Inplace rotary self.rotary_emb(query, torch.select(kv, dim=2, index=0), cos, sin) + if prefill_cache_indices is not None: + kv_to_cache = kv[prefill_cache_indices] + else: + kv_to_cache = kv + kv_cache.store( - key=kv[:, :, 0].contiguous(), - value=kv[:, :, 1].contiguous(), + key=kv_to_cache[:, :, 0].contiguous(), + value=kv_to_cache[:, :, 1].contiguous(), slots=slots, kv_scales=self.kv_scales, ) @@ -350,8 +363,8 @@ class FlashRWLargeAttention(torch.nn.Module): self.softmax_scale, block_tables, seqlen, - max_s, kv_scales=self.kv_scales, + hpu_attention_meta=hpu_attention_meta, ) return self.dense( @@ -437,7 +450,8 @@ class FlashRWLayer(nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ): if self.parallel_attn: ln_hidden_states, residual = self.input_layernorm(hidden_states, residual) @@ -451,7 +465,8 @@ class FlashRWLayer(nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ) mlp_output = self.mlp(ln_hidden_states) @@ -473,7 +488,8 @@ class FlashRWLayer(nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ) if self.post_attention_layernorm is not None: @@ -560,7 +576,8 @@ class FlashRWLargeLayer(nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ): # Layer norm. ln_attn, ln_mlp, residual = self.ln_layer(hidden_states, residual) @@ -575,7 +592,8 @@ class FlashRWLargeLayer(nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ) # MLP. 
@@ -636,15 +654,14 @@ class FlashRWModel(FlashRWPreTrainedModel): block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, + prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: hidden_states = self.word_embeddings(input_ids) # Get rotary cos and sin for this forward # Avoid to index in each layer - cos, sin = self.h[0].self_attention.rotary_emb.get_cos_sin( - position_ids, max_s, hidden_states.dtype - ) + cos, sin = self.h[0].self_attention.rotary_emb.get_cos_sin(position_ids) residual = None for i, layer in enumerate(self.h): @@ -658,7 +675,8 @@ class FlashRWModel(FlashRWPreTrainedModel): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ) hidden_states, _ = self.ln_f(hidden_states, residual) @@ -688,8 +706,8 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel): block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, ) -> torch.Tensor: @@ -701,7 +719,8 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index ed053eb6..9b24e8ba 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -9,6 +9,7 @@ from text_generation_server.layers.attention import ( paged_attention, attention, Seqlen, + HPUPagedAttentionMetadata, ) from text_generation_server.layers import ( TensorParallelRowLinear, @@ -271,7 +272,8 @@ class FlashMQAttention(torch.nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ): qkv = self.c_attn(hidden_states) @@ -284,9 +286,14 @@ class FlashMQAttention(torch.nn.Module): query = query.view(-1, self.num_heads, self.head_size) key_value = key_value.view(-1, 2, 1, self.head_size) + if prefill_cache_indices is not None: + key_value_to_cache = key_value[prefill_cache_indices] + else: + key_value_to_cache = key_value + kv_cache.store( - key=key_value[:, 0], - value=key_value[:, 1], + key=key_value_to_cache[:, 0], + value=key_value_to_cache[:, 1], slots=slots, kv_scales=self.kv_scales, ) @@ -313,8 +320,8 @@ class FlashMQAttention(torch.nn.Module): self.softmax_scale, block_tables, seqlen, - max_s, kv_scales=self.kv_scales, + hpu_attention_meta=hpu_attention_meta, ) return self.c_proj(attn_output.view(-1, self.num_heads * self.head_size)) @@ -379,7 +386,8 @@ class Block(nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ): hidden_states, residual = self.ln_1(hidden_states, residual) hidden_states = self.self_attn( @@ -389,7 +397,8 @@ class Block(nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ) hidden_states, residual = self.ln_2(hidden_states, residual) @@ -443,7 +452,8 @@ class FlashSantacoderModel(nn.Module): block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - 
max_s: int, + prefill_cache_indices: Optional[torch.Tensor], + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: hidden_states = self.wte(input_ids) + self.wpe(position_ids) @@ -460,7 +470,8 @@ class FlashSantacoderModel(nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ) hidden_states, _ = self.ln_f(hidden_states, residual) @@ -492,7 +503,7 @@ class FlashSantacoderForCausalLM(nn.Module): block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, @@ -505,7 +516,8 @@ class FlashSantacoderForCausalLM(nn.Module): block_tables, slots, seqlen, - max_s, + prefill_cache_indices, + hpu_attention_meta, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py index 5e090369..d12bee5c 100644 --- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py +++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py @@ -30,6 +30,7 @@ from text_generation_server.layers.attention import ( paged_attention, attention, Seqlen, + HPUPagedAttentionMetadata, ) from text_generation_server.layers import ( TensorParallelMultiAdapterLinear, @@ -237,9 +238,9 @@ class Starcoder2Attention(torch.nn.Module): block_tables, slots, seqlen, - max_s, prefill_cache_indices, adapter_data, + hpu_attention_meta, ): qkv = self.query_key_value(hidden_states, adapter_data) query, kv = qkv.split( @@ -289,8 +290,8 @@ class Starcoder2Attention(torch.nn.Module): self.softmax_scale, block_tables, seqlen, - max_s, kv_scales=self.kv_scales, + hpu_attention_meta=hpu_attention_meta, ) return self.o_proj( @@ -450,9 +451,9 @@ class Starcoder2Layer(nn.Module): block_tables, slots, seqlen, - max_s, prefill_cache_indices, adapter_data, + hpu_attention_meta, ): normed_hidden_states, res = self.input_layernorm(hidden_states, residual) @@ -466,9 +467,9 @@ class Starcoder2Layer(nn.Module): block_tables, slots, seqlen, - max_s, prefill_cache_indices, adapter_data, + hpu_attention_meta, ) # faster post attention rms norm @@ -520,18 +521,15 @@ class Starcoder2Model(torch.nn.Module): block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, - true_max_s: int, prefill_cache_indices: Optional[torch.Tensor], adapter_data, + hpu_attention_meta: Optional[HPUPagedAttentionMetadata], ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) # Get rotary cos and sin for this forward # Avoid to index in each layer - cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( - position_ids, true_max_s, hidden_states.dtype - ) + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids) residual = None for i, layer in enumerate(self.layers): @@ -545,9 +543,9 @@ class Starcoder2Model(torch.nn.Module): block_tables, slots, seqlen, - max_s, prefill_cache_indices, adapter_data, + hpu_attention_meta, ) hidden_states, _ = self.norm(hidden_states, residual) @@ -594,13 +592,14 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module): block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, - max_s: int, + hpu_attention_meta: 
Optional[HPUPagedAttentionMetadata], prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, ) -> torch.Tensor: - true_max_s = max_s - if prefill_cache_indices is not None: + if prefill_cache_indices is not None and slots.size( + 0 + ) != prefill_cache_indices.size(0): # Slots also need to be sliced as it has the same size as the whole kv tensor slots = slots[prefill_cache_indices] elif self.max_past is not None: @@ -616,10 +615,9 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module): block_tables, slots, seqlen, - max_s, - true_max_s, prefill_cache_indices, adapter_data, + hpu_attention_meta, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] diff --git a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py index 49313c83..27e1c672 100644 --- a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py +++ b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py @@ -1009,24 +1009,22 @@ class FlashCausalLMBatch(Batch): # padding to left to work with sliding window # use prefill_cache_indices to indicate the valid kv slot, update prefill_next_token_indices to indicate # the right logit position - - if device.type == "hpu": - input_ids_padded = None - input_ids_padded_length = None - if isinstance(self.input_ids, list) and len(self) > 1: - input_ids_padded = [] - input_ids_padded_length = [] - for input_id in self.input_ids: - padded = self.max_input_length - len(input_id) - input_id_padded = input_id - if padded > 0: - input_id_padded = [0] * padded + input_id_padded - input_ids_padded.append(input_id_padded) - input_ids_padded_length.append(padded) - input_ids_padded = np.concatenate(input_ids_padded, dtype=np.int64) - input_ids_padded = torch.tensor( - input_ids_padded, dtype=torch.int64, device=device - ) + input_ids_padded = None + input_ids_padded_length = None + if isinstance(self.input_ids, list) and len(self) > 1: + input_ids_padded = [] + input_ids_padded_length = [] + for input_id in self.input_ids: + padded = self.max_input_length - len(input_id) + input_id_padded = input_id + if padded > 0: + input_id_padded = [0] * padded + input_id_padded + input_ids_padded.append(input_id_padded) + input_ids_padded_length.append(padded) + input_ids_padded = np.concatenate(input_ids_padded, dtype=np.int64) + input_ids_padded = torch.tensor( + input_ids_padded, dtype=torch.int64, device=device + ) if isinstance(self.input_ids, list): if len(self) > 1: @@ -1084,7 +1082,7 @@ class FlashCausalLMBatch(Batch): request_position_ids = torch.arange( cache_length, cache_length + input_length, dtype=torch.int32 ) - if device.type == "hpu" and input_ids_padded is not None: + if input_ids_padded is not None: position_ids.append( torch.ones(input_ids_padded_length[i], dtype=torch.int32) ) @@ -1111,7 +1109,7 @@ class FlashCausalLMBatch(Batch): cumulative_slot_tokens += len(request_slots) # Create tensor to slice into the kv tensor in prefill - if device.type == "hpu" and input_ids_padded is not None: + if input_ids_padded is not None: # hpu need request_prefill_cache_indices to skip padding in kv cache sliding_window = get_sliding_windows() if sliding_window is None: @@ -1235,7 +1233,7 @@ class FlashCausalLMBatch(Batch): self.prefill_head_indices = prefill_head_indices self.prefill_next_token_indices = prefill_next_token_indices - if device.type == "hpu" and input_ids_padded is 
not None: + if input_ids_padded is not None: self.input_ids = input_ids_padded input_ids_padded_length_tensor = torch.cumsum( torch.tensor(input_ids_padded_length, dtype=torch.int64, device=device),