From 5ce89059f8149eaf313c63e9ded4199670cd74bb Mon Sep 17 00:00:00 2001 From: OlivierDehaene Date: Mon, 12 Jun 2023 18:30:29 +0200 Subject: [PATCH] feat(server): pre-allocate past key values for flash causal LM (#412) --- server/Makefile-flash-att | 4 +- .../custom_modeling/flash_llama_modeling.py | 114 ++++--- .../custom_modeling/flash_neox_modeling.py | 121 +++++--- .../custom_modeling/flash_rw_modeling.py | 182 +++++++---- .../flash_santacoder_modeling.py | 125 ++++---- .../models/flash_causal_lm.py | 293 +++++++++--------- 6 files changed, 494 insertions(+), 345 deletions(-) diff --git a/server/Makefile-flash-att b/server/Makefile-flash-att index ad894bfa..0e67a9e4 100644 --- a/server/Makefile-flash-att +++ b/server/Makefile-flash-att @@ -1,9 +1,9 @@ -flash_att_commit := d478eeec8f16c7939c54e4617dbd36f59b8eeed7 +flash_att_commit := 06ece1a1525ebcf4e183ac76b1e5108d2872f57f flash-attention: # Clone flash attention pip install packaging - git clone https://github.com/HazyResearch/flash-attention.git + git clone https://github.com/OlivierDehaene/flash-attention.git build-flash-attention: flash-attention cd flash-attention && git fetch && git checkout $(flash_att_commit) diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 8a35ffa8..993e1e2a 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -128,11 +128,14 @@ class FlashLlamaAttention(torch.nn.Module): hidden_states, cos, sin, - cu_seqlens, + start_seq, + end_seq, + start_seq_q, + end_seq_q, max_s, layer_past, - layer_past_present_indices, - cu_seqlens_q, + past_present_indices, + prefill, ): qkv = self.query_key_value(hidden_states) qkv = qkv.view(-1, 3, self.num_heads, self.head_size) @@ -142,7 +145,7 @@ class FlashLlamaAttention(torch.nn.Module): self.rotary_emb(qkv[:, 1], cos, sin) # Prefill - if layer_past_present_indices is None: + if prefill: # Copy to layer past layer_past[...] = qkv[:, 1:] @@ -154,8 +157,10 @@ class FlashLlamaAttention(torch.nn.Module): qkv[:, 1], qkv[:, 2], attn_output, - cu_seqlens, - cu_seqlens, + start_seq, + end_seq, + start_seq, + end_seq, max_s, max_s, 0.0, @@ -170,7 +175,7 @@ class FlashLlamaAttention(torch.nn.Module): else: query = qkv[:, 0] # Add present to the layer_past tensor at the correct indices - layer_past[layer_past_present_indices] = qkv[:, 1:] + layer_past[past_present_indices] = qkv[:, 1:] # output attn_output = torch.empty_like(query) @@ -180,8 +185,10 @@ class FlashLlamaAttention(torch.nn.Module): layer_past[:, 0], layer_past[:, 1], attn_output, - cu_seqlens_q, - cu_seqlens, + start_seq_q, + end_seq_q, + start_seq, + end_seq, 1, max_s, 0.0, @@ -258,11 +265,14 @@ class FlashLlamaLayer(nn.Module): residual, cos, sin, - cu_seqlens, + start_seq, + end_seq, + start_seq_q, + end_seq_q, max_s, layer_past, - layer_past_present_indices, - cu_seqlens_q, + past_present_indices, + prefill, ): normed_hidden_states, res = self.input_layernorm(hidden_states, residual) @@ -271,11 +281,14 @@ class FlashLlamaLayer(nn.Module): normed_hidden_states, cos, sin, - cu_seqlens, + start_seq, + end_seq, + start_seq_q, + end_seq_q, max_s, layer_past, - layer_past_present_indices, - cu_seqlens_q, + past_present_indices, + prefill, ) # faster post attention rms norm @@ -322,35 +335,37 @@ class FlashLlamaModel(torch.nn.Module): self, input_ids, position_ids, - cu_seqlens, - cu_seqlens_q, + start_seq, + end_seq, + start_seq_q, + end_seq_q, max_s, - past_key_values: Optional[torch.Tensor] = None, + past_present_indices, + past_key_values=None, pre_allocate_past_size: Optional[int] = None, ): hidden_states = self.embed_tokens(input_ids) # Prefill if past_key_values is None: + assert pre_allocate_past_size is not None + + prefill = True + # Create past tensor + # We create a tensor of the same size as input_ids as we don't want to slice at every layer past_key_values = hidden_states.new_empty( ( + len(input_ids), len(self.layers), - len(hidden_states) - if pre_allocate_past_size is None - else pre_allocate_past_size, 2, self.num_heads, self.head_size, ) ) - layer_past_present_indices = None - slice_past_index = len(hidden_states) # Decode else: - # Create indices from cumulative sequence lengths - layer_past_present_indices = cu_seqlens[1:] - 1 - slice_past_index = None + prefill = False # Get rotary cos and sin for this forward # Avoid to index in each layer @@ -360,25 +375,36 @@ class FlashLlamaModel(torch.nn.Module): residual = None for i, layer in enumerate(self.layers): - # We added padding that we now need to slice - layer_past_key_values = ( - past_key_values[i] - if slice_past_index is None - else past_key_values[i, :slice_past_index] - ) - hidden_states, residual = layer( hidden_states, residual, cos, sin, - cu_seqlens, + start_seq, + end_seq, + start_seq_q, + end_seq_q, max_s, - layer_past_key_values, - layer_past_present_indices, - cu_seqlens_q, + past_key_values[:, i], + past_present_indices, + prefill, ) + if prefill: + present = past_key_values + # Create padded past tensor + past_key_values = hidden_states.new_empty( + ( + pre_allocate_past_size, + len(self.layers), + 2, + self.num_heads, + self.head_size, + ) + ) + # We slice only once instead of at every layer + past_key_values[past_present_indices] = present + hidden_states, _ = self.norm(hidden_states, residual) return hidden_states, past_key_values @@ -399,9 +425,12 @@ class FlashLlamaForCausalLM(torch.nn.Module): self, input_ids, position_ids, - cu_seqlens, - cu_seqlens_q, + start_seq, + end_seq, + start_seq_q, + end_seq_q, max_s, + past_present_indices, past_key_values: Optional[torch.Tensor] = None, pre_allocate_past_size: Optional[int] = None, lm_head_indices: Optional[torch.Tensor] = None, @@ -409,9 +438,12 @@ class FlashLlamaForCausalLM(torch.nn.Module): hidden_states, present = self.model( input_ids, position_ids, - cu_seqlens, - cu_seqlens_q, + start_seq, + end_seq, + start_seq_q, + end_seq_q, max_s, + past_present_indices, past_key_values, pre_allocate_past_size, ) diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index c045f16e..3586b85a 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -113,11 +113,14 @@ class FlashNeoxAttention(torch.nn.Module): hidden_states, cos, sin, - cu_seqlens, + start_seq, + end_seq, + start_seq_q, + end_seq_q, max_s, layer_past, - layer_past_present_indices, - cu_seqlens_q, + past_present_indices, + prefill, ): qkv = self.query_key_value(hidden_states) qkv = qkv.view(-1, 3, self.num_heads, self.head_size) @@ -127,7 +130,7 @@ class FlashNeoxAttention(torch.nn.Module): self.rotary_emb(qkv[:, 1], cos, sin) # Prefill - if layer_past_present_indices is None: + if prefill: # Copy to layer past layer_past[...] = qkv[:, 1:] @@ -139,8 +142,10 @@ class FlashNeoxAttention(torch.nn.Module): qkv[:, 1], qkv[:, 2], attn_output, - cu_seqlens, - cu_seqlens, + start_seq, + end_seq, + start_seq, + end_seq, max_s, max_s, 0.0, @@ -155,7 +160,7 @@ class FlashNeoxAttention(torch.nn.Module): else: query = qkv[:, 0] # Add present to the layer_past tensor at the correct indices - layer_past[layer_past_present_indices] = qkv[:, 1:] + layer_past[past_present_indices] = qkv[:, 1:] # output attn_output = torch.empty_like(query) @@ -165,8 +170,10 @@ class FlashNeoxAttention(torch.nn.Module): layer_past[:, 0], layer_past[:, 1], attn_output, - cu_seqlens_q, - cu_seqlens, + start_seq_q, + end_seq_q, + start_seq, + end_seq, 1, max_s, 0.0, @@ -240,11 +247,14 @@ class FlashNeoXLayer(nn.Module): residual, cos, sin, - cu_seqlens, + start_seq, + end_seq, + start_seq_q, + end_seq_q, max_s, layer_past, - layer_past_present_indices, - cu_seqlens_q, + past_present_indices, + prefill, ): if self.use_parallel_residual: ln1_hidden_states, _ = self.input_layernorm(hidden_states) @@ -253,11 +263,14 @@ class FlashNeoXLayer(nn.Module): ln1_hidden_states, cos, sin, - cu_seqlens, + start_seq, + end_seq, + start_seq_q, + end_seq_q, max_s, layer_past, - layer_past_present_indices, - cu_seqlens_q, + past_present_indices, + prefill, ) ln2_hidden_states, _ = self.post_attention_layernorm(hidden_states) @@ -276,11 +289,14 @@ class FlashNeoXLayer(nn.Module): hidden_states, cos, sin, - cu_seqlens, + start_seq, + end_seq, + start_seq_q, + end_seq_q, max_s, layer_past, - layer_past_present_indices, - cu_seqlens_q, + past_present_indices, + prefill, ) hidden_states, residual = self.post_attention_layernorm( @@ -329,9 +345,12 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel): self, input_ids, position_ids, - cu_seqlens, - cu_seqlens_q, + start_seq, + end_seq, + start_seq_q, + end_seq_q, max_s, + past_present_indices, past_key_values=None, pre_allocate_past_size: Optional[int] = None, ): @@ -339,25 +358,24 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel): # Prefill if past_key_values is None: + assert pre_allocate_past_size is not None + + prefill = True + # Create past tensor + # We create a tensor of the same size as input_ids as we don't want to slice at every layer past_key_values = hidden_states.new_empty( ( + len(input_ids), len(self.layers), - len(hidden_states) - if pre_allocate_past_size is None - else pre_allocate_past_size, 2, self.num_heads, self.head_size, ) ) - layer_past_present_indices = None - slice_past_index = len(hidden_states) # Decode else: - # Create indices from cumulative sequence lengths - layer_past_present_indices = cu_seqlens[1:] - 1 - slice_past_index = None + prefill = False # Get rotary cos and sin for this forward # Avoid to index in each layer @@ -367,25 +385,36 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel): residual = None for i, layer in enumerate(self.layers): - # We added padding that we now need to slice - layer_past_key_values = ( - past_key_values[i] - if slice_past_index is None - else past_key_values[i, :slice_past_index] - ) - hidden_states, residual = layer( hidden_states, residual, cos, sin, - cu_seqlens, + start_seq, + end_seq, + start_seq_q, + end_seq_q, max_s, - layer_past_key_values, - layer_past_present_indices, - cu_seqlens_q, + past_key_values[:, i], + past_present_indices, + prefill, ) + if prefill: + present = past_key_values + # Create padded past tensor + past_key_values = hidden_states.new_empty( + ( + pre_allocate_past_size, + len(self.layers), + 2, + self.num_heads, + self.head_size, + ) + ) + # We slice only once instead of at every layer + past_key_values[past_present_indices] = present + hidden_states, _ = self.final_layer_norm(hidden_states, residual) return hidden_states, past_key_values @@ -404,9 +433,12 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel): self, input_ids, position_ids, - cu_seqlens, - cu_seqlens_q, + start_seq, + end_seq, + start_seq_q, + end_seq_q, max_s, + past_present_indices, past_key_values: Optional[torch.Tensor] = None, pre_allocate_past_size: Optional[int] = None, lm_head_indices: Optional[torch.Tensor] = None, @@ -414,9 +446,12 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel): hidden_states, present = self.gpt_neox( input_ids, position_ids, - cu_seqlens, - cu_seqlens_q, + start_seq, + end_seq, + start_seq_q, + end_seq_q, max_s, + past_present_indices, past_key_values, pre_allocate_past_size, ) diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index af9fa548..4a9063eb 100644 --- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -130,11 +130,14 @@ class FlashRWAttention(torch.nn.Module): hidden_states, cos, sin, - cu_seqlens, + start_seq, + end_seq, + start_seq_q, + end_seq_q, max_s, layer_past, - layer_past_present_indices, - cu_seqlens_q, + past_present_indices, + prefill, ): qkv = self.query_key_value(hidden_states) @@ -150,10 +153,10 @@ class FlashRWAttention(torch.nn.Module): # Inplace rotary self.rotary_emb(query, cos, sin) - self.rotary_emb(kv[:, 0], cos, sin) + self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin) # Prefill - if layer_past_present_indices is None: + if prefill: # Copy to layer past layer_past[...] = kv # Expand to query shape @@ -164,11 +167,13 @@ class FlashRWAttention(torch.nn.Module): # flash attention flash_attn_cuda.fwd( query, - kv[:, 0], - kv[:, 1], + torch.select(kv, dim=1, index=0), + torch.select(kv, dim=1, index=1), attn_output, - cu_seqlens, - cu_seqlens, + start_seq, + end_seq, + start_seq, + end_seq, max_s, max_s, 0.0, @@ -182,7 +187,7 @@ class FlashRWAttention(torch.nn.Module): # Decode else: # Add present to the layer_past tensor at the correct indices - layer_past[layer_past_present_indices] = kv + layer_past[past_present_indices] = kv # Expand to query shape kv = layer_past.expand(-1, 2, self.num_heads, self.head_size) @@ -191,11 +196,13 @@ class FlashRWAttention(torch.nn.Module): # flash attention flash_attn_cuda.fwd( query, - kv[:, 0], - kv[:, 1], + torch.select(kv, dim=1, index=0), + torch.select(kv, dim=1, index=1), attn_output, - cu_seqlens_q, - cu_seqlens, + start_seq_q, + end_seq_q, + start_seq, + end_seq, 1, max_s, 0.0, @@ -261,11 +268,14 @@ class FlashRWLargeAttention(torch.nn.Module): hidden_states, cos, sin, - cu_seqlens, + start_seq, + end_seq, + start_seq_q, + end_seq_q, max_s, layer_past, - layer_past_present_indices, - cu_seqlens_q, + past_present_indices, + prefill, ): qkv = self.query_key_value(hidden_states) qkv = qkv.view(-1, self.num_groups, self.num_heads + 2, self.head_size) @@ -280,10 +290,10 @@ class FlashRWLargeAttention(torch.nn.Module): # Inplace rotary self.rotary_emb(query, cos, sin) - self.rotary_emb(kv[:, :, 0], cos, sin) + self.rotary_emb(torch.select(kv, dim=2, index=0), cos, sin) # Prefill - if layer_past_present_indices is None: + if prefill: # Copy to layer past layer_past[...] = kv # Expand to query shape @@ -298,11 +308,13 @@ class FlashRWLargeAttention(torch.nn.Module): # flash attention flash_attn_cuda.fwd( query, - kv[:, :, 0], - kv[:, :, 1], + torch.select(kv, dim=2, index=0), + torch.select(kv, dim=2, index=1), attn_output, - cu_seqlens, - cu_seqlens, + start_seq, + end_seq, + start_seq, + end_seq, max_s, max_s, 0.0, @@ -316,7 +328,7 @@ class FlashRWLargeAttention(torch.nn.Module): # Decode else: # Add present to the layer_past tensor at the correct indices - layer_past[layer_past_present_indices] = kv + layer_past[past_present_indices] = kv # Expand to query shape kv = ( layer_past.unsqueeze(2) @@ -329,11 +341,13 @@ class FlashRWLargeAttention(torch.nn.Module): # flash attention flash_attn_cuda.fwd( query, - kv[:, :, 0], - kv[:, :, 1], + torch.select(kv, dim=2, index=0), + torch.select(kv, dim=2, index=1), attn_output, - cu_seqlens_q, - cu_seqlens, + start_seq_q, + end_seq_q, + start_seq, + end_seq, 1, max_s, 0.0, @@ -417,11 +431,14 @@ class FlashRWLayer(nn.Module): residual, cos, sin, - cu_seqlens, + start_seq, + end_seq, + start_seq_q, + end_seq_q, max_s, layer_past, - layer_past_present_indices, - cu_seqlens_q, + past_present_indices, + prefill, ): if self.parallel_attn: ln_hidden_states, residual = self.input_layernorm(hidden_states, residual) @@ -430,11 +447,14 @@ class FlashRWLayer(nn.Module): ln_hidden_states, cos, sin, - cu_seqlens, + start_seq, + end_seq, + start_seq_q, + end_seq_q, max_s, layer_past, - layer_past_present_indices, - cu_seqlens_q, + past_present_indices, + prefill, ) mlp_output = self.mlp(ln_hidden_states) @@ -451,11 +471,14 @@ class FlashRWLayer(nn.Module): hidden_states, cos, sin, - cu_seqlens, + start_seq, + end_seq, + start_seq_q, + end_seq_q, max_s, layer_past, - layer_past_present_indices, - cu_seqlens_q, + past_present_indices, + prefill, ) hidden_states, residual = self.post_attention_layernorm( @@ -499,11 +522,14 @@ class FlashRWLargeLayer(nn.Module): residual, cos, sin, - cu_seqlens, + start_seq, + end_seq, + start_seq_q, + end_seq_q, max_s, layer_past, - layer_past_present_indices, - cu_seqlens_q, + past_present_indices, + prefill, ): ln_attn, residual = self.ln_attn(hidden_states, residual) ln_mlp, _ = self.ln_mlp(residual) @@ -513,11 +539,14 @@ class FlashRWLargeLayer(nn.Module): ln_attn, cos, sin, - cu_seqlens, + start_seq, + end_seq, + start_seq_q, + end_seq_q, max_s, layer_past, - layer_past_present_indices, - cu_seqlens_q, + past_present_indices, + prefill, ) # MLP. @@ -584,9 +613,12 @@ class FlashRWModel(FlashRWPreTrainedModel): self, input_ids, position_ids, - cu_seqlens, - cu_seqlens_q, + start_seq, + end_seq, + start_seq_q, + end_seq_q, max_s, + past_present_indices, past_key_values=None, pre_allocate_past_size: Optional[int] = None, ): @@ -594,23 +626,22 @@ class FlashRWModel(FlashRWPreTrainedModel): # Prefill if past_key_values is None: + assert pre_allocate_past_size is not None + + prefill = True + # Create past tensor + # We create a tensor of the same size as input_ids as we don't want to slice at every layer past_key_values = hidden_states.new_empty( ( + len(input_ids), len(self.h), - len(hidden_states) - if pre_allocate_past_size is None - else pre_allocate_past_size, *self.cache_size, ) ) - layer_past_present_indices = None - slice_past_index = len(hidden_states) # Decode else: - # Create indices from cumulative sequence lengths - layer_past_present_indices = cu_seqlens[1:] - 1 - slice_past_index = None + prefill = False # Get rotary cos and sin for this forward # Avoid to index in each layer @@ -620,25 +651,34 @@ class FlashRWModel(FlashRWPreTrainedModel): residual = None for i, layer in enumerate(self.h): - # We added padding that we now need to slice - layer_past_key_values = ( - past_key_values[i] - if slice_past_index is None - else past_key_values[i, :slice_past_index] - ) - hidden_states, residual = layer( hidden_states, residual, cos, sin, - cu_seqlens, + start_seq, + end_seq, + start_seq_q, + end_seq_q, max_s, - layer_past_key_values, - layer_past_present_indices, - cu_seqlens_q, + torch.select(past_key_values, dim=1, index=i), + past_present_indices, + prefill, ) + if prefill: + present = past_key_values + # Create padded past tensor + past_key_values = hidden_states.new_empty( + ( + pre_allocate_past_size, + len(self.h), + *self.cache_size, + ) + ) + # We slice only once instead of at every layer + past_key_values[past_present_indices] = present + hidden_states, _ = self.ln_f(hidden_states, residual) return hidden_states, past_key_values @@ -658,9 +698,12 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel): self, input_ids, position_ids, - cu_seqlens, - cu_seqlens_q, + start_seq, + end_seq, + start_seq_q, + end_seq_q, max_s, + past_present_indices, past_key_values: Optional[torch.Tensor] = None, pre_allocate_past_size: Optional[int] = None, lm_head_indices: Optional[torch.Tensor] = None, @@ -668,9 +711,12 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel): hidden_states, present = self.transformer( input_ids, position_ids, - cu_seqlens, - cu_seqlens_q, + start_seq, + end_seq, + start_seq_q, + end_seq_q, max_s, + past_present_indices, past_key_values, pre_allocate_past_size, ) diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index fcf6be68..00cc47b6 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -7,6 +7,7 @@ from typing import Optional # Flash attention imports import flash_attn_cuda + from text_generation_server.utils.layers import ( TensorParallelRowLinear, TensorParallelColumnLinear, @@ -148,11 +149,14 @@ class FlashMQAttention(torch.nn.Module): def forward( self, hidden_states, - cu_seqlens, + start_seq, + end_seq, + start_seq_q, + end_seq_q, max_s, layer_past, - layer_past_present_indices, - cu_seqlens_q, + past_present_indices, + prefill, ): qkv = self.c_attn(hidden_states) @@ -166,7 +170,7 @@ class FlashMQAttention(torch.nn.Module): key_value = key_value.view(-1, 2, 1, self.head_size) # Prefill - if layer_past_present_indices is None: + if prefill: # Copy to layer past layer_past[...] = key_value # Expand from 1 to num_heads @@ -177,11 +181,13 @@ class FlashMQAttention(torch.nn.Module): # flash attention flash_attn_cuda.fwd( query, - key_value[:, 0], - key_value[:, 1], + torch.select(key_value, dim=1, index=0), + torch.select(key_value, dim=1, index=1), attn_output, - cu_seqlens, - cu_seqlens, + start_seq, + end_seq, + start_seq, + end_seq, max_s, max_s, 0.0, @@ -195,7 +201,7 @@ class FlashMQAttention(torch.nn.Module): # Decode else: # Add present to the layer_past tensor at the correct indices - layer_past[layer_past_present_indices] = key_value + layer_past[past_present_indices] = key_value # Expand from 1 to num_heads key_value = layer_past.expand(-1, 2, self.num_heads, self.head_size) @@ -204,11 +210,13 @@ class FlashMQAttention(torch.nn.Module): # flash attention flash_attn_cuda.fwd( query, - key_value[:, 0], - key_value[:, 1], + torch.select(key_value, dim=1, index=0), + torch.select(key_value, dim=1, index=1), attn_output, - cu_seqlens_q, - cu_seqlens, + start_seq_q, + end_seq_q, + start_seq, + end_seq, 1, max_s, 0.0, @@ -277,21 +285,27 @@ class Block(nn.Module): self, hidden_states, residual, - cu_seqlens, + start_seq, + end_seq, + start_seq_q, + end_seq_q, max_s, layer_past, - layer_past_present_indices, - cu_seqlens_q, + past_present_indices, + prefill, ): hidden_states, residual = self.ln_1(hidden_states, residual) hidden_states = self.attn( hidden_states, - cu_seqlens, + start_seq, + end_seq, + start_seq_q, + end_seq_q, max_s, layer_past, - layer_past_present_indices, - cu_seqlens_q, + past_present_indices, + prefill, ) hidden_states, residual = self.ln_2(hidden_states, residual) @@ -339,10 +353,13 @@ class FlashSantacoderModel(nn.Module): self, input_ids, position_ids, - cu_seqlens, - cu_seqlens_q, + start_seq, + end_seq, + start_seq_q, + end_seq_q, max_s, - past_key_values: Optional[torch.Tensor] = None, + past_present_indices, + past_key_values=None, pre_allocate_past_size: Optional[int] = None, ): hidden_states = self.wte(input_ids) + self.wpe(position_ids) @@ -352,45 +369,43 @@ class FlashSantacoderModel(nn.Module): # Prefill if past_key_values is None: + assert pre_allocate_past_size is not None + + prefill = True + # Create past tensor - past_key_values = hidden_states.new_empty( - ( - len(self.h), - len(hidden_states) - if pre_allocate_past_size is None - else pre_allocate_past_size, - 2, - 1, - self.head_size, - ) + # We create a tensor of the same size as input_ids as we don't want to slice at every layer + past_key_values = hidden_states.new_zeros( + (len(input_ids), len(self.h), 2, 1, self.head_size) ) - layer_past_present_indices = None - slice_past_index = len(hidden_states) # Decode else: - # Create indices from cumulative sequence lengths - layer_past_present_indices = cu_seqlens[1:] - 1 - slice_past_index = None + prefill = False residual = None for i, layer in enumerate(self.h): - # We added padding that we now need to slice - layer_past_key_values = ( - past_key_values[i] - if slice_past_index is None - else past_key_values[i, :slice_past_index] - ) - hidden_states, residual = layer( hidden_states, residual, - cu_seqlens, + start_seq, + end_seq, + start_seq_q, + end_seq_q, max_s, - layer_past_key_values, - layer_past_present_indices, - cu_seqlens_q, + torch.select(past_key_values, dim=1, index=i), + past_present_indices, + prefill, ) + if prefill: + present = past_key_values + # Create padded past tensor + past_key_values = hidden_states.new_empty( + (pre_allocate_past_size, len(self.h), 2, 1, self.head_size) + ) + # We slice only once instead of at every layer + past_key_values[past_present_indices] = present + hidden_states, _ = self.ln_f(hidden_states, residual) return hidden_states, past_key_values @@ -408,9 +423,12 @@ class FlashSantacoderForCausalLM(nn.Module): self, input_ids, position_ids, - cu_seqlens, - cu_seqlens_q, + start_seq, + end_seq, + start_seq_q, + end_seq_q, max_s, + past_present_indices, past_key_values: Optional[torch.Tensor] = None, pre_allocate_past_size: Optional[int] = None, lm_head_indices: Optional[torch.Tensor] = None, @@ -418,9 +436,12 @@ class FlashSantacoderForCausalLM(nn.Module): hidden_states, present = self.transformer( input_ids, position_ids, - cu_seqlens, - cu_seqlens_q, + start_seq, + end_seq, + start_seq_q, + end_seq_q, max_s, + past_present_indices, past_key_values, pre_allocate_past_size, ) diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index a2ad2d5e..ecea998e 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -3,8 +3,6 @@ import torch.distributed import numpy as np -from torch.nn import functional as F - from dataclasses import dataclass from opentelemetry import trace from transformers import AutoTokenizer, PreTrainedTokenizerBase, PreTrainedModel @@ -34,10 +32,21 @@ class FlashCausalLMBatch(Batch): input_ids: torch.Tensor position_ids: torch.Tensor - # cumulative sequence lengths - cu_seqlens: torch.Tensor - # cumulative query sequence lengths, only used in decode - cu_seqlens_q: Optional[torch.Tensor] + # Indices to copy present to the correct indices is the pre-allocated past key values + past_present_indices: torch.Tensor + + # tensor of length b holding starting offset of each sequence + start_seq: torch.Tensor + # tensor of length b holding ending offset of each sequence + end_seq: torch.Tensor + # tensor of length b holding starting offset of each sequence, only used in prefill + start_seq_prefill: Optional[torch.Tensor] + # tensor of length b holding ending offset of each sequence, only used in prefill + end_seq_prefill: Optional[torch.Tensor] + # tensor of length b holding starting offset of each query sequence, only used in decode + start_seq_q: Optional[torch.Tensor] + # tensor of length b holding ending offset of each query sequence, only used in decode + end_seq_q: Optional[torch.Tensor] # past key values, only used in decode past_key_values: Optional[torch.Tensor] max_seqlen: int @@ -90,7 +99,11 @@ class FlashCausalLMBatch(Batch): )["input_ids"] position_ids = [] - cu_seqlens = [0] + past_present_indices = [] + start_seq = [] + end_seq = [] + start_seq_prefill = [] + end_seq_prefill = [] max_seqlen = 0 input_lengths = [] @@ -110,9 +123,9 @@ class FlashCausalLMBatch(Batch): # Cumulative length cumulative_length = 0 + cumulative_max_length = 0 prefill_out_cumulative_length = 0 - max_tokens = 0 max_length = 0 # Parse batch @@ -138,7 +151,10 @@ class FlashCausalLMBatch(Batch): position_ids.append(request_position_ids) # Add cumulative lengths of all previous inputs - cu_seqlens.append(cumulative_length + input_length) + start_seq_prefill.append(cumulative_length) + end_seq_prefill.append(cumulative_length + input_length) + start_seq.append(cumulative_max_length) + end_seq.append(cumulative_max_length + input_length) next_token_chooser_parameters.append(r.parameters) @@ -168,9 +184,17 @@ class FlashCausalLMBatch(Batch): prefill_cu_outlens.append(prefill_out_cumulative_length + 1) prefill_out_cumulative_length += 1 + request_past_present_indices = torch.arange( + cumulative_max_length, + cumulative_max_length + input_length, + dtype=torch.int64, + ) + past_present_indices.append(request_past_present_indices) + # Update + # Remove one as the first token des not have a past cumulative_length += input_length - max_tokens += input_length + max_new_tokens + cumulative_max_length += input_length + max_new_tokens - 1 max_length = max(max_length, input_length + max_new_tokens) next_token_chooser = HeterogeneousNextTokenChooser.from_pb( @@ -184,26 +208,45 @@ class FlashCausalLMBatch(Batch): for i, input_ids in enumerate(all_input_ids): all_input_ids_tensor[i, : len(input_ids)] = input_ids + # Create tensors on device + all_input_ids_tensor = torch.tensor( + all_input_ids_tensor, dtype=torch.int64, device=device + ) + start_seq = torch.tensor(start_seq, device=device, dtype=torch.int32) + end_seq = torch.tensor(end_seq, device=device, dtype=torch.int32) + if len(pb.requests) > 1: input_ids = np.concatenate(all_input_ids, dtype=np.int64) position_ids = torch.cat(position_ids) + + past_present_indices = np.concatenate(past_present_indices, dtype=np.int64) + + start_seq_prefill = torch.tensor( + start_seq_prefill, device=device, dtype=torch.int32 + ) + end_seq_prefill = torch.tensor( + end_seq_prefill, device=device, dtype=torch.int32 + ) else: input_ids = all_input_ids[0] position_ids = position_ids[0] - # Create tensors on device + past_present_indices = past_present_indices[0] + + start_seq_prefill = start_seq + end_seq_prefill = end_seq + input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device) - all_input_ids_tensor = torch.tensor( - all_input_ids_tensor, dtype=torch.int64, device=device - ) position_ids = torch.tensor(position_ids, dtype=torch.int32, device=device) - cu_seqlens = torch.tensor(cu_seqlens, device=device, dtype=torch.int32) + past_present_indices = torch.tensor( + past_present_indices, device=device, dtype=torch.int64 + ) if all_prefill_logprobs: prefill_head_indices = None - prefill_next_token_indices = cu_seqlens[1:] - 1 + prefill_next_token_indices = end_seq_prefill - 1 elif no_prefill_logprobs: - prefill_head_indices = cu_seqlens[1:] - 1 + prefill_head_indices = end_seq_prefill - 1 prefill_next_token_indices = None else: prefill_head_indices = torch.tensor( @@ -219,8 +262,13 @@ class FlashCausalLMBatch(Batch): requests_idx_mapping=requests_idx_mapping, input_ids=input_ids, position_ids=position_ids, - cu_seqlens=cu_seqlens, - cu_seqlens_q=None, + past_present_indices=past_present_indices, + start_seq=start_seq, + end_seq=end_seq, + start_seq_prefill=start_seq_prefill, + end_seq_prefill=end_seq_prefill, + start_seq_q=None, + end_seq_q=None, max_seqlen=max_seqlen, prefill_head_indices=prefill_head_indices, prefill_next_token_indices=prefill_next_token_indices, @@ -233,7 +281,7 @@ class FlashCausalLMBatch(Batch): all_input_ids_tensor=all_input_ids_tensor, next_token_chooser=next_token_chooser, stopping_criterias=stopping_criterias, - max_tokens=max_tokens, + max_tokens=cumulative_max_length, ) @tracer.start_as_current_span("filter") @@ -244,10 +292,10 @@ class FlashCausalLMBatch(Batch): if len(request_ids) == len(self): return self - single_request = len(request_ids) == 1 + device = self.input_ids.device # Cumulative length - cumulative_length = 0 + cumulative_max_length = 0 # New values after filtering requests_idx_mapping = {} @@ -255,11 +303,17 @@ class FlashCausalLMBatch(Batch): # Used to index into tensors indices = [] + # past indices to keep + past_indices = torch.zeros( + self.past_key_values.shape[0], dtype=torch.bool, device=device + ) + # Create on CPU to only move to GPU once instead of at every copy - cu_seqlens = torch.zeros(len(request_ids) + 1, dtype=torch.int32) - cu_seqlens_q = self.cu_seqlens_q[: len(request_ids) + 1] + start_seq = torch.empty(len(request_ids), dtype=torch.int32) + end_seq = torch.empty(len(request_ids), dtype=torch.int32) + start_seq_q = self.start_seq_q[: len(request_ids)] + end_seq_q = self.end_seq_q[: len(request_ids)] max_seqlen = 0 - past_key_values = [] requests = [] all_input_ids = [] @@ -270,8 +324,6 @@ class FlashCausalLMBatch(Batch): stopping_criterias = [] - max_tokens = 0 - for i, request_id in enumerate(request_ids): idx = self.requests_idx_mapping[request_id] indices.append(idx) @@ -281,16 +333,8 @@ class FlashCausalLMBatch(Batch): # Get length request_input_length = self.input_lengths[idx] - - # Copy to tensor (CPU) - cu_seqlens[i + 1] = cumulative_length + request_input_length max_seqlen = max(max_seqlen, request_input_length) - # Slice from past - past_key_values.append( - self.past_key_values[:, self.cu_seqlens[idx] : self.cu_seqlens[idx + 1]] - ) - all_input_ids.append(self.all_input_ids[idx]) input_lengths.append(request_input_length) @@ -300,39 +344,32 @@ class FlashCausalLMBatch(Batch): stopping_criteria = self.stopping_criterias[idx] stopping_criterias.append(stopping_criteria) - cumulative_length += request_input_length - max_tokens += request_input_length + ( + remaining_tokens = ( stopping_criteria.max_new_tokens - stopping_criteria.current_tokens ) - if single_request: - # Preallocate tensor for bs = 1 case - past_key_values = F.pad( - past_key_values[0], - ( - 0, - 0, - 0, - 0, - 0, - 0, - 0, - stopping_criterias[0].max_new_tokens - - stopping_criterias[0].current_tokens, - ), - ) - else: - # Cat all past - past_key_values = torch.cat(past_key_values, dim=1) + # Copy to tensor (CPU) + start_seq[i] = cumulative_max_length + end_seq[i] = cumulative_max_length + request_input_length + + # Set slice + past_indices[ + self.start_seq[idx] : self.end_seq[idx] + remaining_tokens - 1 + ] = True + + cumulative_max_length += request_input_length + remaining_tokens - 1 # Index into tensors input_ids = self.input_ids[indices] position_ids = self.position_ids[indices] all_input_ids_tensor = self.all_input_ids_tensor[indices] next_token_chooser = self.next_token_chooser.filter(indices) + past_key_values = self.past_key_values[past_indices] # Move to GPU now that we have the whole tensor - cu_seqlens = cu_seqlens.to(self.cu_seqlens.device) + start_seq = start_seq.to(device) + end_seq = end_seq.to(device) + past_present_indices = end_seq - 1 return FlashCausalLMBatch( batch_id=self.batch_id, @@ -340,8 +377,13 @@ class FlashCausalLMBatch(Batch): requests_idx_mapping=requests_idx_mapping, input_ids=input_ids, position_ids=position_ids, - cu_seqlens=cu_seqlens, - cu_seqlens_q=cu_seqlens_q, + past_present_indices=past_present_indices, + start_seq=start_seq, + end_seq=end_seq, + start_seq_prefill=None, + end_seq_prefill=None, + start_seq_q=start_seq_q, + end_seq_q=end_seq_q, max_seqlen=max_seqlen, prefill_head_indices=None, prefill_next_token_indices=None, @@ -354,7 +396,7 @@ class FlashCausalLMBatch(Batch): all_input_ids_tensor=all_input_ids_tensor, next_token_chooser=next_token_chooser, stopping_criterias=stopping_criterias, - max_tokens=max_tokens, + max_tokens=cumulative_max_length, ) @classmethod @@ -371,10 +413,12 @@ class FlashCausalLMBatch(Batch): input_ids = batches[0].input_ids.new_empty(total_batch_size) position_ids = batches[0].position_ids.new_empty(total_batch_size) - cu_seqlens = [0] - cu_seqlens_q = torch.arange( - 0, total_batch_size + 1, device=device, dtype=torch.int32 + start_seq = batches[0].start_seq.new_empty(total_batch_size) + end_seq = batches[0].end_seq.new_empty(total_batch_size) + start_seq_q = torch.arange( + 0, total_batch_size, device=device, dtype=torch.int32 ) + end_seq_q = start_seq_q + 1 max_seqlen = 0 past_key_values = [] @@ -389,7 +433,6 @@ class FlashCausalLMBatch(Batch): # Cumulative length cumulative_batch_size = 0 - cumulative_length = 0 max_tokens = 0 max_length = 0 @@ -410,18 +453,10 @@ class FlashCausalLMBatch(Batch): input_ids[start_index:end_index] = batch.input_ids position_ids[start_index:end_index] = batch.position_ids - # Add cumulative lengths of all previous inputs - cu_seqlens.extend([l + cumulative_length for l in batch.cu_seqlens[1:]]) - max_seqlen = max(max_seqlen, batch.max_seqlen) + start_seq[start_index:end_index] = batch.start_seq + max_tokens + end_seq[start_index:end_index] = batch.end_seq + max_tokens - if len(batch) != 1: - past_key_values.append(batch.past_key_values) - else: - # past was pre-allocated for this batch - # We need to slice to remove the padding - past_key_values.append( - batch.past_key_values[:, : batch.input_lengths[0]] - ) + max_seqlen = max(max_seqlen, batch.max_seqlen) all_input_ids.extend(batch.all_input_ids) @@ -431,9 +466,9 @@ class FlashCausalLMBatch(Batch): next_token_chooser_parameters.extend([r.parameters for r in batch.requests]) stopping_criterias.extend(batch.stopping_criterias) + past_key_values.append(batch.past_key_values) # Update - cumulative_length += batch.cu_seqlens[-1] cumulative_batch_size += len(batch) max_tokens += batch.max_tokens max_length = max( @@ -448,6 +483,9 @@ class FlashCausalLMBatch(Batch): ), ) + past_key_values = torch.cat(past_key_values, dim=0) + past_present_indices = end_seq - 1 + all_input_ids_tensor = torch.zeros( (total_batch_size, max_length), dtype=torch.int64, device=device ) @@ -463,11 +501,6 @@ class FlashCausalLMBatch(Batch): cumulative_batch_size += len(batch) - # Cat past - past_key_values = torch.cat(past_key_values, dim=1) - # Create final tensor on GPU - cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32, device=device) - next_token_chooser = HeterogeneousNextTokenChooser.from_pb( next_token_chooser_parameters, dtype=dtype, device=device ) @@ -478,8 +511,13 @@ class FlashCausalLMBatch(Batch): requests_idx_mapping=requests_idx_mapping, input_ids=input_ids, position_ids=position_ids, - cu_seqlens=cu_seqlens, - cu_seqlens_q=cu_seqlens_q, + past_present_indices=past_present_indices, + start_seq=start_seq, + end_seq=end_seq, + start_seq_prefill=None, + end_seq_prefill=None, + start_seq_q=start_seq_q, + end_seq_q=end_seq_q, max_seqlen=max_seqlen, prefill_head_indices=None, prefill_next_token_indices=None, @@ -550,9 +588,12 @@ class FlashCausalLM(Model): self, input_ids: torch.Tensor, position_ids: torch.Tensor, - cu_seqlens: torch.Tensor, - cu_seqlens_q: Optional[torch.Tensor], + start_seq: torch.Tensor, + end_seq: torch.Tensor, + start_seq_q: Optional[torch.Tensor], + end_seq_q: Optional[torch.Tensor], max_s: int, + past_present_indices: torch.Tensor, past_key_values: Optional = None, pre_allocate_past_size: Optional[int] = None, lm_head_indices: Optional[torch.Tensor] = None, @@ -561,9 +602,12 @@ class FlashCausalLM(Model): return self.model.forward( input_ids=input_ids, position_ids=position_ids, - cu_seqlens=cu_seqlens, - cu_seqlens_q=cu_seqlens_q, + start_seq=start_seq, + end_seq=end_seq, + start_seq_q=start_seq_q, + end_seq_q=end_seq_q, max_s=max_s, + past_present_indices=past_present_indices, past_key_values=past_key_values, pre_allocate_past_size=pre_allocate_past_size, lm_head_indices=lm_head_indices, @@ -575,23 +619,27 @@ class FlashCausalLM(Model): ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]: prefill = batch.past_key_values is None prefill_logprobs = batch.prefill_next_token_indices is not None - single_request = len(batch) == 1 - if prefill and single_request: + if prefill: # Ask to pre-allocate kv to its max size - # == number of tokens + max_new_tokens - pre_allocate_past_size = ( - batch.input_lengths[0] + batch.stopping_criterias[0].max_new_tokens - ) + # == Sum over batch size (number of tokens + max_new_tokens) - batch size + pre_allocate_past_size = batch.max_tokens + start_seq = batch.start_seq_prefill + end_seq = batch.end_seq_prefill else: pre_allocate_past_size = None + start_seq = batch.start_seq + end_seq = batch.end_seq out, present = self.forward( batch.input_ids, batch.position_ids, - batch.cu_seqlens, - batch.cu_seqlens_q, + start_seq, + end_seq, + batch.start_seq_q, + batch.end_seq_q, batch.max_seqlen, + batch.past_present_indices, batch.past_key_values, pre_allocate_past_size, batch.prefill_head_indices, @@ -614,55 +662,19 @@ class FlashCausalLM(Model): # When batch == 1, we will just use the batch.input_ids values directly prefill_tokens_indices = batch.input_ids.new_zeros(len(out)) - # Create batch.cu_seqlens_q for decode - batch.cu_seqlens_q = torch.arange( - 0, len(batch) + 1, device=self.device, dtype=torch.int32 + # Create batch.start_seq_q and batch.end_seq_q for decode + batch.start_seq_q = torch.arange( + 0, len(batch), device=self.device, dtype=torch.int32 ) + batch.end_seq_q = batch.start_seq_q + 1 next_position_ids = batch.position_ids.new_empty(len(batch)) + # We do not need start_seq_prefill and end_seq_prefill anymore + batch.start_seq_prefill = None + batch.end_seq_prefill = None else: prefill_logprobs = None next_position_ids = batch.position_ids - # Prepare past for next decode - if len(batch) > 1: - # Used to slice next batch past - past_indices = torch.empty( - present.shape[1], dtype=torch.int64, device=self.device - ) - batch.past_key_values = present.new_empty( - ( - present.shape[0], - present.shape[1] + len(batch.requests), - *present.shape[2:], - ) - ) - - # It is actually faster to do a whole other for loop here as the copy from present to past is fairly slow - # and will run asynchronously while we do the next for loop - cumulative_length = 0 - for i, input_length in enumerate(batch.input_lengths): - # Indexing metadata - start_index = cumulative_length - end_index = cumulative_length + input_length - - # Indices to copy present at the correct place in past_key_values - torch.arange( - start_index + i, - end_index + i, - dtype=torch.int64, - device=self.device, - out=past_indices[start_index:end_index], - ) - cumulative_length += input_length - - # Copy from present to past_key_values - batch.past_key_values[:, past_indices] = present - - # Initialize past_key_values in prefill for len(batch) == 1 - elif prefill: - # present is already pre-padded - batch.past_key_values = present - # Cumulative length cumulative_length = 0 @@ -685,6 +697,7 @@ class FlashCausalLM(Model): input_length, all_input_ids, ) in enumerate(iterator): + # Indexing metadata start_index = cumulative_length end_index = cumulative_length + input_length @@ -718,7 +731,8 @@ class FlashCausalLM(Model): # Set values in batch batch.input_ids = next_input_ids batch.position_ids = next_position_ids + 1 - batch.cu_seqlens = batch.cu_seqlens + batch.cu_seqlens_q + batch.past_present_indices = batch.end_seq + batch.end_seq = batch.end_seq + 1 if prefill and prefill_logprobs: # Get prefill logprobs @@ -843,6 +857,7 @@ class FlashCausalLM(Model): batch.prefill_head_indices = None batch.prefill_next_token_indices = None batch.max_seqlen = batch.max_seqlen + 1 + batch.past_key_values = present # No need to return a batch if we know that all requests stopped return generations, batch if not stopped else None