diff --git a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py
index 1eb8c6c3..fe19180a 100644
--- a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py
@@ -29,6 +29,7 @@ from text_generation_server.layers.attention import (
     paged_attention,
     attention,
     reshape_and_cache,
+    Seqlen,
 )
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers import (
@@ -264,7 +265,7 @@ class FlashCohereAttention(torch.nn.Module):
         kv_cache,
         block_tables,
         slots,
-        input_lengths,
+        seqlen,
         max_s,
     ):
         qkv = self.query_key_value(hidden_states)
@@ -296,12 +297,10 @@ class FlashCohereAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                key,
-                value,
                 kv_cache[0],
                 kv_cache[1],
-                cu_seqlen_prefill,
-                max_s,
+                seqlen,
+                block_tables,
                 self.softmax_scale,
             )
         # Decode
@@ -313,7 +312,7 @@ class FlashCohereAttention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
-                input_lengths,
+                seqlen,
                 max_s,
             )

@@ -388,7 +387,7 @@ class FlashCohereLayer(nn.Module):
         kv_cache,
         block_tables,
         slots,
-        input_lengths,
+        seqlen,
         max_s,
     ):
         normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@@ -402,7 +401,7 @@ class FlashCohereLayer(nn.Module):
             kv_cache,
             block_tables,
             slots,
-            input_lengths,
+            seqlen,
             max_s,
         )

@@ -454,7 +453,7 @@ class FlashCohereModel(torch.nn.Module):
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
         block_tables: torch.Tensor,
         slots: torch.Tensor,
-        input_lengths: torch.Tensor,
+        seqlen: torch.Tensor,
         max_s: int,
     ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
@@ -477,7 +476,7 @@ class FlashCohereModel(torch.nn.Module):
                 kv_cache[i],
                 block_tables,
                 slots,
-                input_lengths,
+                seqlen,
                 max_s,
             )

@@ -518,7 +517,7 @@ class FlashCohereForCausalLM(torch.nn.Module):
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
         block_tables: torch.Tensor,
         slots: torch.Tensor,
-        input_lengths: torch.Tensor,
+        seqlen: Seqlen,
         max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor] = None,
@@ -531,7 +530,7 @@ class FlashCohereForCausalLM(torch.nn.Module):
             kv_cache,
             block_tables,
             slots,
-            input_lengths,
+            seqlen,
             max_s,
         )
         if lm_head_indices is not None:
diff --git a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py
index fc0dca5b..b82b5473 100644
--- a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py
@@ -29,6 +29,7 @@ from text_generation_server.layers.attention import (
     paged_attention,
     attention,
     reshape_and_cache,
+    Seqlen,
 )
 from text_generation_server.layers import (
     FastLinear,
@@ -309,7 +310,7 @@ class DbrxAttention(torch.nn.Module):
         kv_cache,
         block_tables,
         slots,
-        input_lengths,
+        seqlen,
         max_s,
     ):
         qkv = self.query_key_value(hidden_states)
@@ -335,12 +336,10 @@ class DbrxAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                torch.select(kv, dim=1, index=0),
-                torch.select(kv, dim=1, index=1),
                 kv_cache[0],
                 kv_cache[1],
-                cu_seqlen_prefill,
-                max_s,
+                seqlen,
+                block_tables,
                 self.softmax_scale,
             )
         # Decode
@@ -352,7 +351,7 @@ class DbrxAttention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
-                input_lengths,
+                seqlen,
                 max_s,
             )

@@ -389,7 +388,7 @@ class DbrxNormAttentionNorm(nn.Module):
         kv_cache,
         block_tables,
         slots,
-        input_lengths,
+        seqlen,
         max_s,
     ):
         normed_hidden_states, res = self.norm_1(hidden_states, residual)
@@ -403,7 +402,7 @@ class DbrxNormAttentionNorm(nn.Module):
             kv_cache,
             block_tables,
             slots,
-            input_lengths,
+            seqlen,
             max_s,
         )

@@ -622,7 +621,7 @@ class DbrxLayer(nn.Module):
         kv_cache,
         block_tables,
         slots,
-        input_lengths,
+        seqlen,
         max_s,
     ):
         # Self Attention
@@ -635,7 +634,7 @@ class DbrxLayer(nn.Module):
             kv_cache,
             block_tables,
             slots,
-            input_lengths,
+            seqlen,
             max_s,
         )

@@ -679,7 +678,7 @@ class DbrxModel(torch.nn.Module):
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
         block_tables: torch.Tensor,
         slots: torch.Tensor,
-        input_lengths: torch.Tensor,
+        seqlen: Seqlen,
         max_s: int,
     ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
@@ -701,7 +700,7 @@ class DbrxModel(torch.nn.Module):
                 kv_cache[i],
                 block_tables,
                 slots,
-                input_lengths,
+                seqlen,
                 max_s,
             )

@@ -734,7 +733,7 @@ class FlashDbrxForCausalLM(torch.nn.Module):
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
         block_tables: torch.Tensor,
         slots: torch.Tensor,
-        input_lengths: torch.Tensor,
+        seqlen: Seqlen,
         max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor] = None,
@@ -747,7 +746,7 @@ class FlashDbrxForCausalLM(torch.nn.Module):
             kv_cache,
             block_tables,
             slots,
-            input_lengths,
+            seqlen,
             max_s,
         )
         if lm_head_indices is not None:
diff --git a/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py
index b25becd5..0585b40e 100644
--- a/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py
@@ -29,8 +29,8 @@ from text_generation_server.layers.attention import (
     attention,
     paged_attention,
     reshape_and_cache,
+    Seqlen,
 )
-from text_generation_server.layers.attention.common import Seqlen
 from text_generation_server.layers.layernorm import FastRMSNorm
 from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale
 from text_generation_server.utils.import_utils import SYSTEM
@@ -298,7 +298,7 @@ class DeepseekV2Attention(torch.nn.Module):
         kv_cache: Tuple[torch.Tensor, torch.Tensor],
         block_tables: torch.Tensor,
         slots: torch.Tensor,
-        input_lengths: Seqlen,
+        seqlen: Seqlen,
         max_s: int,
     ):
         if self.q_lora_rank is None:
@@ -363,12 +363,10 @@ class DeepseekV2Attention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                key,
-                value,
                 kv_cache[0],
                 kv_cache[1],
-                cu_seqlen_prefill,
-                max_s,
+                seqlen,
+                block_tables,
                 self.softmax_scale,
             )
         # Decode
@@ -380,7 +378,7 @@ class DeepseekV2Attention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
-                input_lengths,
+                seqlen,
                 max_s,
             )

@@ -666,7 +664,7 @@ class DeepseekV2Layer(nn.Module):
         kv_cache,
         block_tables: torch.Tensor,
         slots: torch.Tensor,
-        input_lengths: Seqlen,
+        seqlen: Seqlen,
         max_s: int,
     ):
         normed_hidden_states, residual = self.input_layernorm(hidden_states, residual)
@@ -680,7 +678,7 @@ class DeepseekV2Layer(nn.Module):
             kv_cache,
             block_tables,
             slots,
-            input_lengths,
+            seqlen,
             max_s,
         )

@@ -729,7 +727,7 @@ class DeepseekV2Model(torch.nn.Module):
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
         block_tables: torch.Tensor,
         slots: torch.Tensor,
-        input_lengths: torch.Tensor,
+        seqlen: Seqlen,
         max_s: int,
     ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
@@ -751,7 +749,7 @@ class DeepseekV2Model(torch.nn.Module):
                 kv_cache[i],
                 block_tables,
                 slots,
-                input_lengths,
+                seqlen,
                 max_s,
             )

@@ -781,7 +779,7 @@ class FlashDeepseekV2ForCausalLM(torch.nn.Module):
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
         block_tables: torch.Tensor,
         slots: torch.Tensor,
-        input_lengths: torch.Tensor,
+        seqlen: Seqlen,
         max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor] = None,
@@ -794,7 +792,7 @@ class FlashDeepseekV2ForCausalLM(torch.nn.Module):
             kv_cache,
             block_tables,
             slots,
-            input_lengths,
+            seqlen,
             max_s,
         )
         if lm_head_indices is not None:
diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py
index faf0f325..d16e805f 100644
--- a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py
@@ -30,6 +30,7 @@ from text_generation_server.layers.attention import (
     paged_attention,
     attention,
     reshape_and_cache,
+    Seqlen,
 )
 from text_generation_server.layers import (
     TensorParallelRowLinear,
@@ -213,7 +214,7 @@ class FlashGemma2Attention(torch.nn.Module):
         kv_cache,
         block_tables,
         slots,
-        input_lengths,
+        seqlen,
         max_s,
     ):
         qkv = self.query_key_value(hidden_states)
@@ -236,12 +237,10 @@ class FlashGemma2Attention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                torch.select(kv, dim=1, index=0),
-                torch.select(kv, dim=1, index=1),
                 kv_cache[0],
                 kv_cache[1],
-                cu_seqlen_prefill,
-                max_s,
+                seqlen,
+                block_tables,
                 self.softmax_scale,
                 causal=self.causal,
                 window_size_left=self.window_size,
@@ -256,7 +255,7 @@ class FlashGemma2Attention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
-                input_lengths,
+                seqlen,
                 max_s,
                 softcap=self.softcap,
             )
@@ -343,7 +342,7 @@ class FlashGemma2Layer(nn.Module):
         kv_cache,
         block_tables,
         slots,
-        input_lengths,
+        seqlen,
         max_s,
     ):
         normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@@ -357,7 +356,7 @@ class FlashGemma2Layer(nn.Module):
             kv_cache,
             block_tables,
             slots,
-            input_lengths,
+            seqlen,
             max_s,
         )

@@ -408,7 +407,7 @@ class FlashGemma2Model(torch.nn.Module):
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
         block_tables: torch.Tensor,
         slots: torch.Tensor,
-        input_lengths: torch.Tensor,
+        seqlen: Seqlen,
         max_s: int,
     ) -> torch.Tensor:
         hidden_states = inputs_embeds
@@ -430,7 +429,7 @@ class FlashGemma2Model(torch.nn.Module):
                 kv_cache[i],
                 block_tables,
                 slots,
-                input_lengths,
+                seqlen,
                 max_s,
             )

@@ -477,7 +476,7 @@ class FlashGemma2ForCausalLM(torch.nn.Module):
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
         block_tables: torch.Tensor,
         slots: torch.Tensor,
-        input_lengths: torch.Tensor,
+        seqlen: Seqlen,
         max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor] = None,
@@ -491,7 +490,7 @@ class FlashGemma2ForCausalLM(torch.nn.Module):
             kv_cache,
             block_tables,
             slots,
-            input_lengths,
+            seqlen,
             max_s,
         )
         if lm_head_indices is not None:
diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py
index 33738a59..34be4cb8 100644
--- a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py
@@ -30,6 +30,7 @@ from text_generation_server.layers.attention import (
     paged_attention,
     attention,
     reshape_and_cache,
+    Seqlen,
 )
 from text_generation_server.layers import (
     TensorParallelRowLinear,
@@ -207,7 +208,7 @@ class FlashGemmaAttention(torch.nn.Module):
         kv_cache,
         block_tables,
         slots,
-        input_lengths,
+        seqlen,
         max_s,
     ):
         qkv = self.query_key_value(hidden_states)
@@ -230,12 +231,10 @@ class FlashGemmaAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                torch.select(kv, dim=1, index=0),
-                torch.select(kv, dim=1, index=1),
                 kv_cache[0],
                 kv_cache[1],
-                cu_seqlen_prefill,
-                max_s,
+                seqlen,
+                block_tables,
                 self.softmax_scale,
                 causal=self.causal,
             )
@@ -248,7 +247,7 @@ class FlashGemmaAttention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
-                input_lengths,
+                seqlen,
                 max_s,
             )

@@ -320,7 +319,7 @@ class FlashGemmaLayer(nn.Module):
         kv_cache,
         block_tables,
         slots,
-        input_lengths,
+        seqlen,
         max_s,
     ):
         normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@@ -334,7 +333,7 @@ class FlashGemmaLayer(nn.Module):
             kv_cache,
             block_tables,
             slots,
-            input_lengths,
+            seqlen,
             max_s,
         )

@@ -382,7 +381,7 @@ class FlashGemmaModel(torch.nn.Module):
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
         block_tables: torch.Tensor,
         slots: torch.Tensor,
-        input_lengths: torch.Tensor,
+        seqlen: Seqlen,
         max_s: int,
     ) -> torch.Tensor:
         hidden_states = inputs_embeds
@@ -404,7 +403,7 @@ class FlashGemmaModel(torch.nn.Module):
                 kv_cache[i],
                 block_tables,
                 slots,
-                input_lengths,
+                seqlen,
                 max_s,
             )

@@ -449,7 +448,7 @@ class FlashGemmaForCausalLM(torch.nn.Module):
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
         block_tables: torch.Tensor,
         slots: torch.Tensor,
-        input_lengths: torch.Tensor,
+        seqlen: Seqlen,
         max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor] = None,
@@ -463,7 +462,7 @@ class FlashGemmaForCausalLM(torch.nn.Module):
             kv_cache,
             block_tables,
             slots,
-            input_lengths,
+            seqlen,
             max_s,
         )
         if lm_head_indices is not None:
diff --git a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py
index d30b5a0a..403fa908 100644
--- a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py
@@ -29,6 +29,7 @@ from text_generation_server.layers.attention import (
     paged_attention,
     attention,
     reshape_and_cache,
+    Seqlen,
 )
 from text_generation_server.layers import (
     TensorParallelRowLinear,
@@ -213,7 +214,7 @@ class FlashGPT2Attention(torch.nn.Module):
         kv_cache,
         block_tables,
         slots,
-        input_lengths,
+        seqlen,
         max_s,
     ):
         query, key, value = self.query_key_value(hidden_states).split(
@@ -230,12 +231,10 @@ class FlashGPT2Attention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                key,
-                value,
                 kv_cache[0],
                 kv_cache[1],
-                cu_seqlen_prefill,
-                max_s,
+                seqlen,
+                block_tables,
                 self.softmax_scale,
             )
         # Decode
@@ -247,7 +246,7 @@ class FlashGPT2Attention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
-                input_lengths,
+                seqlen,
                 max_s,
             )

@@ -316,7 +315,7 @@ class FlashGPT2Layer(nn.Module):
         kv_cache,
         block_tables,
         slots,
-        input_lengths,
+        seqlen,
         max_s,
     ):
         residual = hidden_states
@@ -329,7 +328,7 @@ class FlashGPT2Layer(nn.Module):
             kv_cache,
             block_tables,
             slots,
-            input_lengths,
+            seqlen,
             max_s,
         )

@@ -382,7 +381,7 @@ class FlashGPT2Model(torch.nn.Module):
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
         block_tables: torch.Tensor,
         slots: torch.Tensor,
-        input_lengths: torch.Tensor,
+        seqlen: Seqlen,
         max_s: int,
         true_max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
@@ -398,7 +397,7 @@ class FlashGPT2Model(torch.nn.Module):
                 kv_cache[i],
                 block_tables,
                 slots,
-                input_lengths,
+                seqlen,
                 max_s,
             )

@@ -435,7 +434,7 @@ class FlashGPT2ForCausalLM(torch.nn.Module):
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
         block_tables: torch.Tensor,
         slots: torch.Tensor,
-        input_lengths: torch.Tensor,
+        seqlen: Seqlen,
         max_s: int,
         prefill_cache_indices: Optional[torch.Tensor] = None,
         lm_head_indices: Optional[torch.Tensor] = None,
@@ -451,7 +450,7 @@ class FlashGPT2ForCausalLM(torch.nn.Module):
             kv_cache,
             block_tables,
             slots,
-            input_lengths,
+            seqlen,
             max_s,
             true_max_s=max_s,
             prefill_cache_indices=prefill_cache_indices,
diff --git a/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py
index eb667384..35ab2791 100644
--- a/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py
@@ -29,6 +29,7 @@ from text_generation_server.layers.attention import (
     paged_attention,
     attention,
     reshape_and_cache,
+    Seqlen,
 )
 from text_generation_server.layers import (
     TensorParallelRowLinear,
@@ -167,7 +168,7 @@ class FlashGPTJAttention(torch.nn.Module):
         kv_cache,
         block_tables,
         slots,
-        input_lengths,
+        seqlen,
         max_s,
     ):
         query, key, value = self.query_key_value(hidden_states).split(
@@ -192,10 +193,10 @@ class FlashGPTJAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                key,
-                value,
-                cu_seqlen_prefill,
-                max_s,
+                kv_cache[0],
+                kv_cache[1],
+                seqlen,
+                block_tables,
                 self.softmax_scale,
             )
         # Decode
@@ -207,7 +208,7 @@ class FlashGPTJAttention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
-                input_lengths,
+                seqlen,
                 max_s,
             )

@@ -268,7 +269,7 @@ class FlashGPTJLayer(nn.Module):
         kv_cache,
         block_tables,
         slots,
-        input_lengths,
+        seqlen,
         max_s,
     ):
         hidden_states, residual = self.input_layernorm(hidden_states, residual)
@@ -281,7 +282,7 @@ class FlashGPTJLayer(nn.Module):
             kv_cache,
             block_tables,
             slots,
-            input_lengths,
+            seqlen,
             max_s,
         )

@@ -328,7 +329,7 @@ class FlashGPTJModel(torch.nn.Module):
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
         block_tables: torch.Tensor,
         slots: torch.Tensor,
-        input_lengths: torch.Tensor,
+        seqlen: Seqlen,
         max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
     ) -> torch.Tensor:
@@ -351,7 +352,7 @@ class FlashGPTJModel(torch.nn.Module):
                 kv_cache[i],
                 block_tables,
                 slots,
-                input_lengths,
+                seqlen,
                 max_s,
             )

@@ -382,7 +383,7 @@ class FlashGPTJForCausalLM(torch.nn.Module):
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
         block_tables: torch.Tensor,
         slots: torch.Tensor,
-        input_lengths: torch.Tensor,
+        seqlen: Seqlen,
         max_s: int,
         prefill_cache_indices: Optional[torch.Tensor] = None,
         lm_head_indices: Optional[torch.Tensor] = None,
@@ -395,7 +396,7 @@ class FlashGPTJForCausalLM(torch.nn.Module):
             kv_cache,
             block_tables,
             slots,
-            input_lengths,
+            seqlen,
             max_s,
             prefill_cache_indices=prefill_cache_indices,
         )
diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
index 5a150267..30ca3faf 100644
--- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
@@ -31,6 +31,7 @@ from text_generation_server.layers.attention import (
     paged_attention,
     attention,
     reshape_and_cache,
+    Seqlen,
 )
 from text_generation_server.layers import (
     TensorParallelRowLinear,
@@ -185,7 +186,7 @@ class MistralAttention(torch.nn.Module):
         kv_cache,
         block_tables,
         slots,
-        input_lengths,
+        seqlen,
         max_s,
         prefill_cache_indices,
         adapter_data,
@@ -217,12 +218,10 @@ class MistralAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                torch.select(kv, dim=1, index=0),
-                torch.select(kv, dim=1, index=1),
                 kv_cache[0],
                 kv_cache[1],
-                cu_seqlen_prefill,
-                max_s,
+                seqlen,
+                block_tables,
                 self.softmax_scale,
                 window_size_left=self.max_past,
             )
@@ -235,7 +234,7 @@ class MistralAttention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
-                input_lengths,
+                seqlen,
                 max_s,
             )

@@ -356,7 +355,7 @@ class MistralLayer(nn.Module):
         kv_cache,
         block_tables,
         slots,
-        input_lengths,
+        seqlen,
         max_s,
         prefill_cache_indices,
         adapter_data,
@@ -372,7 +371,7 @@ class MistralLayer(nn.Module):
             kv_cache,
             block_tables,
             slots,
-            input_lengths,
+            seqlen,
             max_s,
             prefill_cache_indices,
             adapter_data,
@@ -424,7 +423,7 @@ class MistralModel(torch.nn.Module):
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
         block_tables: torch.Tensor,
         slots: torch.Tensor,
-        input_lengths: torch.Tensor,
+        seqlen: Seqlen,
         max_s: int,
         true_max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
@@ -448,7 +447,7 @@ class MistralModel(torch.nn.Module):
                 kv_cache[i],
                 block_tables,
                 slots,
-                input_lengths,
+                seqlen,
                 max_s,
                 prefill_cache_indices,
                 adapter_data,
@@ -499,7 +498,7 @@ class FlashMistralForCausalLM(torch.nn.Module):
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
         block_tables: torch.Tensor,
         slots: torch.Tensor,
-        input_lengths: torch.Tensor,
+        seqlen: Seqlen,
         max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor] = None,
@@ -512,7 +511,7 @@ class FlashMistralForCausalLM(torch.nn.Module):
         elif self.max_past is not None:
             # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
             # kernel requires the true values
-            input_lengths = input_lengths.clamp(max=self.max_past_tensor)
+            seqlen = seqlen.clamp(max=self.max_past_tensor)

         inputs_embeds = self.embed_tokens(input_ids)
         hidden_states = self.model(
@@ -522,7 +521,7 @@ class FlashMistralForCausalLM(torch.nn.Module):
             kv_cache,
             block_tables,
             slots,
-            input_lengths,
+            seqlen,
             max_s,
             true_max_s,
             prefill_cache_indices,
diff --git a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
index ad426ffe..c5d60af1 100644
--- a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
@@ -35,6 +35,7 @@ from text_generation_server.layers.attention import (
     paged_attention,
     attention,
     reshape_and_cache,
+    Seqlen,
 )
 from text_generation_server.layers import (
     FastLinear,
@@ -243,7 +244,7 @@ class MixtralAttention(torch.nn.Module):
         kv_cache,
         block_tables,
         slots,
-        input_lengths,
+        seqlen,
         max_s,
         prefill_cache_indices,
     ):
@@ -274,12 +275,10 @@ class MixtralAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                torch.select(kv, dim=1, index=0),
-                torch.select(kv, dim=1, index=1),
                 kv_cache[0],
                 kv_cache[1],
-                cu_seqlen_prefill,
-                max_s,
+                seqlen,
+                block_tables,
                 self.softmax_scale,
                 window_size_left=self.max_past,
             )
@@ -292,7 +291,7 @@ class MixtralAttention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
-                input_lengths,
+                seqlen,
                 max_s,
             )

@@ -498,7 +497,7 @@ class MixtralLayer(nn.Module):
         kv_cache,
         block_tables,
         slots,
-        input_lengths,
+        seqlen,
         max_s,
         prefill_cache_indices,
     ):
@@ -513,7 +512,7 @@ class MixtralLayer(nn.Module):
             kv_cache,
             block_tables,
             slots,
-            input_lengths,
+            seqlen,
             max_s,
             prefill_cache_indices,
         )
@@ -568,7 +567,7 @@ class MixtralModel(torch.nn.Module):
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
         block_tables: torch.Tensor,
         slots: torch.Tensor,
-        input_lengths: torch.Tensor,
+        seqlen: Seqlen,
         max_s: int,
         true_max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
@@ -592,7 +591,7 @@ class MixtralModel(torch.nn.Module):
                 kv_cache[i],
                 block_tables,
                 slots,
-                input_lengths,
+                seqlen,
                 max_s,
                 prefill_cache_indices,
             )
@@ -627,7 +626,7 @@ class FlashMixtralForCausalLM(torch.nn.Module):
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
         block_tables: torch.Tensor,
         slots: torch.Tensor,
-        input_lengths: torch.Tensor,
+        seqlen: Seqlen,
         max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor] = None,
@@ -640,7 +639,7 @@ class FlashMixtralForCausalLM(torch.nn.Module):
         elif self.max_past is not None:
             # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
             # kernel requires the true values
-            input_lengths = input_lengths.clamp(max=self.max_past_tensor)
+            seqlen = seqlen.clamp(max=self.max_past_tensor)

         hidden_states = self.model(
             input_ids,
@@ -649,7 +648,7 @@ class FlashMixtralForCausalLM(torch.nn.Module):
             kv_cache,
             block_tables,
             slots,
-            input_lengths,
+            seqlen,
             max_s,
             true_max_s,
             prefill_cache_indices,
diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
index b684e035..fda648f9 100644
--- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
@@ -31,6 +31,7 @@ from text_generation_server.layers.attention import (
     paged_attention,
     attention,
     reshape_and_cache,
+    Seqlen,
 )
 from text_generation_server.layers import (
     TensorParallelRowLinear,
@@ -147,7 +148,7 @@ class FlashNeoxAttention(torch.nn.Module):
         kv_cache,
         block_tables,
         slots,
-        input_lengths,
+        seqlen,
         max_s,
     ):
         qkv = self.query_key_value(hidden_states)
@@ -171,12 +172,10 @@ class FlashNeoxAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 qkv[:, 0],
-                qkv[:, 1],
-                qkv[:, 2],
                 kv_cache[0],
                 kv_cache[1],
-                cu_seqlen_prefill,
-                max_s,
+                seqlen,
+                block_tables,
                 self.softmax_scale,
             )
         # Decode
@@ -188,7 +187,7 @@ class FlashNeoxAttention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
-                input_lengths,
+                seqlen,
                 max_s,
             )

@@ -258,7 +257,7 @@ class FlashNeoXLayer(nn.Module):
         kv_cache,
         block_tables,
         slots,
-        input_lengths,
+        seqlen,
         max_s,
     ):
         if self.use_parallel_residual:
@@ -272,7 +271,7 @@ class FlashNeoXLayer(nn.Module):
                 kv_cache,
                 block_tables,
                 slots,
-                input_lengths,
+                seqlen,
                 max_s,
             )

@@ -296,7 +295,7 @@ class FlashNeoXLayer(nn.Module):
                 kv_cache,
                 block_tables,
                 slots,
-                input_lengths,
+                seqlen,
                 max_s,
             )

@@ -350,7 +349,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
         block_tables: torch.Tensor,
         slots: torch.Tensor,
-        input_lengths: torch.Tensor,
+        seqlen: Seqlen,
         max_s: int,
     ) -> torch.Tensor:
         hidden_states = self.embed_in(input_ids)
@@ -372,7 +371,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
                 kv_cache[i],
                 block_tables,
                 slots,
-                input_lengths,
+                seqlen,
                 max_s,
             )

@@ -404,7 +403,7 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
         block_tables: torch.Tensor,
         slots: torch.Tensor,
-        input_lengths: torch.Tensor,
+        seqlen: Seqlen,
         max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor] = None,
@@ -417,7 +416,7 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
             kv_cache,
             block_tables,
             slots,
-            input_lengths,
+            seqlen,
             max_s,
         )
         if lm_head_indices is not None:
diff --git a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py
index efe27c13..37adb8be 100644
--- a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py
@@ -10,6 +10,7 @@ from text_generation_server.layers.attention import (
     paged_attention,
     attention,
     reshape_and_cache,
+    Seqlen,
 )
 from text_generation_server.layers import (
     TensorParallelRowLinear,
@@ -159,7 +160,7 @@ class FlashPhiAttention(torch.nn.Module):
         kv_cache,
         block_tables,
         slots,
-        input_lengths,
+        seqlen,
         max_s,
     ):
         # Compute query, key, value and split
@@ -192,12 +193,10 @@ class FlashPhiAttention(torch.nn.Module):
         if cu_seqlen_prefill is not None:
             attn_output = attention(
                 query,
-                torch.select(kv, dim=1, index=0),
-                torch.select(kv, dim=1, index=1),
                 kv_cache[0],
                 kv_cache[1],
-                cu_seqlen_prefill,
-                max_s,
+                seqlen,
+                block_tables,
                 self.softmax_scale,
             )
         # Decode
@@ -209,7 +208,7 @@ class FlashPhiAttention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
-                input_lengths,
+                seqlen,
                 max_s,
             )

@@ -276,7 +275,7 @@ class FlashPhiLayer(nn.Module):
         kv_cache,
         block_tables,
         slots,
-        input_lengths,
+        seqlen,
         max_s,
     ):
         hidden_states, res = self.input_layernorm(hidden_states, residual)
@@ -289,7 +288,7 @@ class FlashPhiLayer(nn.Module):
             kv_cache,
             block_tables,
             slots,
-            input_lengths,
+            seqlen,
             max_s,
         )

@@ -341,7 +340,7 @@ class FlashPhiModel(torch.nn.Module):
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
         block_tables: torch.Tensor,
         slots: torch.Tensor,
-        input_lengths: torch.Tensor,
+        seqlen: Seqlen,
         max_s: int,
     ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
@@ -363,7 +362,7 @@ class FlashPhiModel(torch.nn.Module):
                 kv_cache[i],
                 block_tables,
                 slots,
-                input_lengths,
+                seqlen,
                 max_s,
             )

@@ -396,7 +395,7 @@ class FlashPhiForCausalLM(torch.nn.Module):
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
         block_tables: torch.Tensor,
         slots: torch.Tensor,
-        input_lengths: torch.Tensor,
+        seqlen: Seqlen,
         max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor] = None,
@@ -409,7 +408,7 @@ class FlashPhiForCausalLM(torch.nn.Module):
             kv_cache,
             block_tables,
             slots,
-            input_lengths,
+            seqlen,
             max_s,
         )
         if lm_head_indices is not None:
diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
index c72a9b90..1c55dd91 100644
--- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
@@ -19,6 +19,7 @@ from text_generation_server.layers.attention import (
     attention,
     paged_attention,
     reshape_and_cache,
+    Seqlen,
 )

@@ -181,7 +182,7 @@ class FlashRWAttention(torch.nn.Module):
         kv_cache,
         block_tables,
         slots,
-        input_lengths,
+        seqlen,
         max_s,
     ):
         qkv = self.query_key_value(hidden_states)
@@ -206,12 +207,10 @@ class FlashRWAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                torch.select(kv, dim=1, index=0),
-                torch.select(kv, dim=1, index=1),
                 kv_cache[0],
                 kv_cache[1],
-                cu_seqlen_prefill,
-                max_s,
+                seqlen,
+                block_tables,
                 self.softmax_scale,
             )
         # Decode
@@ -223,7 +222,7 @@ class FlashRWAttention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
-                input_lengths,
+                seqlen,
                 max_s,
             )

@@ -296,7 +295,7 @@ class FlashRWLargeAttention(torch.nn.Module):
         kv_cache,
         block_tables,
         slots,
-        input_lengths,
+        seqlen,
         max_s,
     ):
         qkv = self.query_key_value(hidden_states)
@@ -343,7 +342,7 @@ class FlashRWLargeAttention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
-                input_lengths,
+                seqlen,
                 max_s,
             )

@@ -429,7 +428,7 @@ class FlashRWLayer(nn.Module):
         kv_cache,
         block_tables,
         slots,
-        input_lengths,
+        seqlen,
         max_s,
     ):
         if self.parallel_attn:
@@ -443,7 +442,7 @@ class FlashRWLayer(nn.Module):
                 kv_cache,
                 block_tables,
                 slots,
-                input_lengths,
+                seqlen,
                 max_s,
             )

@@ -465,7 +464,7 @@ class FlashRWLayer(nn.Module):
                 kv_cache,
                 block_tables,
                 slots,
-                input_lengths,
+                seqlen,
                 max_s,
             )

@@ -552,7 +551,7 @@ class FlashRWLargeLayer(nn.Module):
         kv_cache,
         block_tables,
         slots,
-        input_lengths,
+        seqlen,
         max_s,
     ):
         # Layer norm.
@@ -567,7 +566,7 @@ class FlashRWLargeLayer(nn.Module):
             kv_cache,
             block_tables,
             slots,
-            input_lengths,
+            seqlen,
             max_s,
         )

@@ -628,7 +627,7 @@ class FlashRWModel(FlashRWPreTrainedModel):
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
         block_tables: torch.Tensor,
         slots: torch.Tensor,
-        input_lengths: torch.Tensor,
+        seqlen: Seqlen,
         max_s: int,
     ) -> torch.Tensor:
         hidden_states = self.word_embeddings(input_ids)
@@ -650,7 +649,7 @@ class FlashRWModel(FlashRWPreTrainedModel):
                 kv_cache[i],
                 block_tables,
                 slots,
-                input_lengths,
+                seqlen,
                 max_s,
             )

@@ -680,7 +679,7 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
         block_tables: torch.Tensor,
         slots: torch.Tensor,
-        input_lengths: torch.Tensor,
+        seqlen: Seqlen,
         max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor] = None,
@@ -693,7 +692,7 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
             kv_cache,
             block_tables,
             slots,
-            input_lengths,
+            seqlen,
             max_s,
         )
         if lm_head_indices is not None:
diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
index 109304be..19025c4c 100644
--- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
@@ -9,6 +9,7 @@ from text_generation_server.layers.attention import (
     paged_attention,
     attention,
     reshape_and_cache,
+    Seqlen,
 )
 from text_generation_server.layers import (
     TensorParallelRowLinear,
@@ -268,7 +269,7 @@ class FlashMQAttention(torch.nn.Module):
         kv_cache,
         block_tables,
         slots,
-        input_lengths,
+        seqlen,
         max_s,
     ):
         qkv = self.c_attn(hidden_states)
@@ -291,12 +292,10 @@ class FlashMQAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                torch.select(key_value, dim=1, index=0),
-                torch.select(key_value, dim=1, index=1),
                 kv_cache[0],
                 kv_cache[1],
-                cu_seqlen_prefill,
-                max_s,
+                seqlen,
+                block_tables,
                 self.softmax_scale,
             )
         # Decode
@@ -308,7 +307,7 @@ class FlashMQAttention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
-                input_lengths,
+                seqlen,
                 max_s,
             )

@@ -373,7 +372,7 @@ class Block(nn.Module):
         kv_cache,
         block_tables,
         slots,
-        input_lengths,
+        seqlen,
         max_s,
     ):
         hidden_states, residual = self.ln_1(hidden_states, residual)
@@ -383,7 +382,7 @@ class Block(nn.Module):
             kv_cache,
             block_tables,
             slots,
-            input_lengths,
+            seqlen,
             max_s,
         )

@@ -437,7 +436,7 @@ class FlashSantacoderModel(nn.Module):
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
         block_tables: torch.Tensor,
         slots: torch.Tensor,
-        input_lengths: torch.Tensor,
+        seqlen: Seqlen,
         max_s: int,
     ) -> torch.Tensor:
         hidden_states = self.wte(input_ids) + self.wpe(position_ids)
@@ -454,7 +453,7 @@ class FlashSantacoderModel(nn.Module):
                 kv_cache[i],
                 block_tables,
                 slots,
-                input_lengths,
+                seqlen,
                 max_s,
             )

@@ -486,7 +485,7 @@ class FlashSantacoderForCausalLM(nn.Module):
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
         block_tables: torch.Tensor,
         slots: torch.Tensor,
-        input_lengths: torch.Tensor,
+        seqlen: Seqlen,
         max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor] = None,
@@ -499,7 +498,7 @@ class FlashSantacoderForCausalLM(nn.Module):
             kv_cache,
             block_tables,
             slots,
-            input_lengths,
+            seqlen,
             max_s,
         )
         if lm_head_indices is not None:
diff --git a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py
index 200d4ef0..2f9ecd0d 100644
--- a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py
@@ -30,6 +30,7 @@ from text_generation_server.layers.attention import (
     paged_attention,
     attention,
     reshape_and_cache,
+    Seqlen,
 )
 from text_generation_server.layers import (
     TensorParallelRowLinear,
@@ -209,7 +210,7 @@ class Starcoder2Attention(torch.nn.Module):
         kv_cache,
         block_tables,
         slots,
-        input_lengths,
+        seqlen,
         max_s,
         prefill_cache_indices,
     ):
@@ -240,12 +241,10 @@ class Starcoder2Attention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                torch.select(kv, dim=1, index=0),
-                torch.select(kv, dim=1, index=1),
                 kv_cache[0],
                 kv_cache[1],
-                cu_seqlen_prefill,
-                max_s,
+                seqlen,
+                block_tables,
                 self.softmax_scale,
                 window_size_left=self.max_past,
             )
@@ -258,7 +257,7 @@ class Starcoder2Attention(torch.nn.Module):
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
-                input_lengths,
+                seqlen,
                 max_s,
             )

@@ -381,7 +380,7 @@ class Starcoder2Layer(nn.Module):
         kv_cache,
         block_tables,
         slots,
-        input_lengths,
+        seqlen,
         max_s,
         prefill_cache_indices,
     ):
@@ -396,7 +395,7 @@ class Starcoder2Layer(nn.Module):
             kv_cache,
             block_tables,
             slots,
-            input_lengths,
+            seqlen,
             max_s,
             prefill_cache_indices,
         )
@@ -449,7 +448,7 @@ class Starcoder2Model(torch.nn.Module):
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
         block_tables: torch.Tensor,
         slots: torch.Tensor,
-        input_lengths: torch.Tensor,
+        seqlen: Seqlen,
         max_s: int,
         true_max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
@@ -473,7 +472,7 @@ class Starcoder2Model(torch.nn.Module):
                 kv_cache[i],
                 block_tables,
                 slots,
-                input_lengths,
+                seqlen,
                 max_s,
                 prefill_cache_indices,
             )
@@ -521,7 +520,7 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module):
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
         block_tables: torch.Tensor,
         slots: torch.Tensor,
-        input_lengths: torch.Tensor,
+        seqlen: Seqlen,
         max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor] = None,
@@ -534,7 +533,7 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module):
         elif self.max_past is not None:
             # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
             # kernel requires the true values
-            input_lengths = input_lengths.clamp(max=self.max_past_tensor)
+            seqlen = seqlen.clamp(max=self.max_past_tensor)

         hidden_states = self.model(
             input_ids,
@@ -543,7 +542,7 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module):
             kv_cache,
             block_tables,
             slots,
-            input_lengths,
+            seqlen,
             max_s,
             true_max_s,
             prefill_cache_indices,
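For readability, here is a schematic sketch of the prefill call shape that every file above converges on. It is illustration only: the stub body and the `input_lengths` field on the stand-in `Seqlen` class are assumptions, not TGI's actual implementation (the real `Seqlen` is imported from `text_generation_server.layers.attention`, as the import hunks show).

```python
# Schematic sketch of the new prefill call shape (stand-ins only, not TGI's real kernels).
from dataclasses import dataclass

import torch


@dataclass
class Seqlen:
    # Stand-in for text_generation_server.layers.attention.Seqlen; the real class
    # carries per-request length metadata. The field name here is assumed.
    input_lengths: torch.Tensor


def attention(query, key_cache, value_cache, seqlen, block_tables, softmax_scale, **kwargs):
    # Stub mirroring the signature the diff converges on: prefill now reads K/V from
    # the paged cache (kv_cache[0] / kv_cache[1]) via block_tables and a Seqlen object,
    # instead of taking key/value tensors plus cu_seqlen_prefill / max_s.
    raise NotImplementedError


# Before: attention(query, key, value, kv_cache[0], kv_cache[1], cu_seqlen_prefill, max_s, self.softmax_scale)
# After:  attention(query, kv_cache[0], kv_cache[1], seqlen, block_tables, self.softmax_scale)
```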