From bfd6928c3e36c0dbab6bfb06ba6569d8784eb452 Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Thu, 1 Jun 2023 18:37:14 +0200
Subject: [PATCH] working

---
 .../custom_modeling/flash_rw_modeling.py      |  36 ++---
 .../flash_santacoder_modeling.py              | 125 ++++++++++--------
 .../models/flash_causal_lm.py                 |  48 +++----
 3 files changed, 108 insertions(+), 101 deletions(-)

diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
index 4d1ecd71..c65fd160 100644
--- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
@@ -136,7 +136,7 @@ class FlashRWAttention(torch.nn.Module):
         end_seq_q,
         max_s,
         layer_past,
-        layer_past_present_indices,
+        past_present_indices,
         prefill,
     ):
         qkv = self.query_key_value(hidden_states)
@@ -153,12 +153,12 @@ class FlashRWAttention(torch.nn.Module):
 
         # Inplace rotary
         self.rotary_emb(query, cos, sin)
-        self.rotary_emb(kv[:, 0], cos, sin)
+        self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)
 
         # Prefill
         if prefill:
             # Copy to layer past
-            layer_past[layer_past_present_indices] = kv
+            layer_past[past_present_indices] = kv
 
             # Expand to query shape
             kv = kv.expand(-1, 2, self.num_heads, self.head_size)
@@ -167,8 +167,8 @@ class FlashRWAttention(torch.nn.Module):
             # flash attention
             flash_attn_cuda_modif.fwd(
                 query,
-                kv[:, 0],
-                kv[:, 1],
+                torch.select(kv, dim=1, index=0),
+                torch.select(kv, dim=1, index=1),
                 attn_output,
                 start_seq,
                 end_seq,
@@ -187,7 +187,7 @@ class FlashRWAttention(torch.nn.Module):
         # Decode
         else:
             # Add present to the layer_past tensor at the correct indices
-            layer_past[layer_past_present_indices] = kv
+            layer_past[past_present_indices] = kv
 
             # Expand to query shape
             kv = layer_past.expand(-1, 2, self.num_heads, self.head_size)
@@ -196,8 +196,8 @@ class FlashRWAttention(torch.nn.Module):
             # flash attention
             flash_attn_cuda_modif.fwd(
                 query,
-                kv[:, 0],
-                kv[:, 1],
+                torch.select(kv, dim=1, index=0),
+                torch.select(kv, dim=1, index=1),
                 attn_output,
                 start_seq_q,
                 end_seq_q,
@@ -271,7 +271,7 @@ class FlashRWLargeAttention(torch.nn.Module):
         cu_seqlens,
         max_s,
         layer_past,
-        layer_past_present_indices,
+        past_present_indices,
         cu_seqlens_q,
     ):
         qkv = self.query_key_value(hidden_states)
@@ -290,7 +290,7 @@ class FlashRWLargeAttention(torch.nn.Module):
         self.rotary_emb(kv[:, :, 0], cos, sin)
 
         # Prefill
-        if layer_past_present_indices is None:
+        if past_present_indices is None:
             # Copy to layer past
             layer_past[...] = kv
             # Expand to query shape
@@ -323,7 +323,7 @@ class FlashRWLargeAttention(torch.nn.Module):
         # Decode
         else:
             # Add present to the layer_past tensor at the correct indices
-            layer_past[layer_past_present_indices] = kv
+            layer_past[past_present_indices] = kv
             # Expand to query shape
             kv = (
                 layer_past.unsqueeze(2)
@@ -430,7 +430,7 @@ class FlashRWLayer(nn.Module):
         end_seq_q,
         max_s,
         layer_past,
-        layer_past_present_indices,
+        past_present_indices,
         prefill,
     ):
         if self.parallel_attn:
@@ -446,7 +446,7 @@ class FlashRWLayer(nn.Module):
                 end_seq_q,
                 max_s,
                 layer_past,
-                layer_past_present_indices,
+                past_present_indices,
                 prefill,
             )
 
@@ -469,7 +469,7 @@ class FlashRWLayer(nn.Module):
                 end_seq_q,
                 max_s,
                 layer_past,
-                layer_past_present_indices,
+                past_present_indices,
                 prefill,
             )
 
@@ -517,7 +517,7 @@ class FlashRWLargeLayer(nn.Module):
         cu_seqlens,
         max_s,
         layer_past,
-        layer_past_present_indices,
+        past_present_indices,
         cu_seqlens_q,
     ):
         ln_attn, residual = self.ln_attn(hidden_states, residual)
@@ -531,7 +531,7 @@ class FlashRWLargeLayer(nn.Module):
             cu_seqlens,
             max_s,
             layer_past,
-            layer_past_present_indices,
+            past_present_indices,
             cu_seqlens_q,
         )
 
@@ -619,8 +619,8 @@ class FlashRWModel(FlashRWPreTrainedModel):
             # Create past tensor
             past_key_values = hidden_states.new_zeros(
                 (
-                    len(self.h),
                     pre_allocate_past_size,
+                    len(self.h),
                     *self.cache_size,
                 )
             )
@@ -646,7 +646,7 @@ class FlashRWModel(FlashRWPreTrainedModel):
                 start_seq_q,
                 end_seq_q,
                 max_s,
-                past_key_values[i],
+                past_key_values[:, i],
                 past_present_indices,
                 prefill,
             )
diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
index c22aac65..70e02e76 100644
--- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
@@ -7,6 +7,7 @@ from typing import Optional
 
 # Flash attention imports
 import flash_attn_cuda_modif
+
 from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
@@ -148,11 +149,14 @@ class FlashMQAttention(torch.nn.Module):
     def forward(
         self,
         hidden_states,
-        cu_seqlens,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
         layer_past,
-        layer_past_present_indices,
-        cu_seqlens_q,
+        past_present_indices,
+        prefill,
     ):
         qkv = self.c_attn(hidden_states)
 
@@ -166,9 +170,9 @@ class FlashMQAttention(torch.nn.Module):
         key_value = key_value.view(-1, 2, 1, self.head_size)
 
         # Prefill
-        if layer_past_present_indices is None:
+        if prefill:
             # Copy to layer past
-            layer_past[...] = key_value
+            layer_past[past_present_indices] = key_value
 
             # Expand from 1 to num_heads
             key_value = key_value.expand(-1, 2, self.num_heads, self.head_size)
@@ -177,11 +181,13 @@ class FlashMQAttention(torch.nn.Module):
             # flash attention
             flash_attn_cuda_modif.fwd(
                 query,
-                key_value[:, 0],
-                key_value[:, 1],
+                torch.select(key_value, dim=1, index=0),
+                torch.select(key_value, dim=1, index=1),
                 attn_output,
-                cu_seqlens,
-                cu_seqlens,
+                start_seq,
+                end_seq,
+                start_seq,
+                end_seq,
                 max_s,
                 max_s,
                 0.0,
@@ -195,7 +201,7 @@ class FlashMQAttention(torch.nn.Module):
         # Decode
         else:
             # Add present to the layer_past tensor at the correct indices
-            layer_past[layer_past_present_indices] = key_value
+            layer_past[past_present_indices] = key_value
 
             # Expand from 1 to num_heads
             key_value = layer_past.expand(-1, 2, self.num_heads, self.head_size)
@@ -204,11 +210,13 @@ class FlashMQAttention(torch.nn.Module):
             # flash attention
             flash_attn_cuda_modif.fwd(
                 query,
-                key_value[:, 0],
-                key_value[:, 1],
+                torch.select(key_value, dim=1, index=0),
+                torch.select(key_value, dim=1, index=1),
                 attn_output,
-                cu_seqlens_q,
-                cu_seqlens,
+                start_seq_q,
+                end_seq_q,
+                start_seq,
+                end_seq,
                 1,
                 max_s,
                 0.0,
@@ -277,21 +285,27 @@ class Block(nn.Module):
         self,
         hidden_states,
         residual,
-        cu_seqlens,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
         layer_past,
-        layer_past_present_indices,
-        cu_seqlens_q,
+        past_present_indices,
+        prefill,
     ):
         hidden_states, residual = self.ln_1(hidden_states, residual)
 
         hidden_states = self.attn(
             hidden_states,
-            cu_seqlens,
+            start_seq,
+            end_seq,
+            start_seq_q,
+            end_seq_q,
             max_s,
             layer_past,
-            layer_past_present_indices,
-            cu_seqlens_q,
+            past_present_indices,
+            prefill,
         )
 
         hidden_states, residual = self.ln_2(hidden_states, residual)
@@ -339,10 +353,13 @@ class FlashSantacoderModel(nn.Module):
         self,
         input_ids,
         position_ids,
-        cu_seqlens,
-        cu_seqlens_q,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
         max_s,
-        past_key_values: Optional[torch.Tensor] = None,
+        past_present_indices,
+        past_key_values=None,
         pre_allocate_past_size: Optional[int] = None,
     ):
         hidden_states = self.wte(input_ids) + self.wpe(position_ids)
@@ -350,43 +367,37 @@ class FlashSantacoderModel(nn.Module):
 
         # Prefill
        if past_key_values is None:
+            assert pre_allocate_past_size is not None
+
+            prefill = True
+
             # Create past tensor
-            past_key_values = hidden_states.new_empty(
+            past_key_values = hidden_states.new_zeros(
                 (
+                    pre_allocate_past_size,
                     len(self.h),
-                    len(hidden_states)
-                    if pre_allocate_past_size is None
-                    else pre_allocate_past_size,
                     2,
                     1,
-                    self.head_size,
+                    self.head_size
                 )
             )
-            layer_past_present_indices = None
-            slice_past_index = len(hidden_states)
         # Decode
         else:
-            # Create indices from cumulative sequence lengths
-            layer_past_present_indices = cu_seqlens[1:] - 1
-            slice_past_index = None
+            prefill = False
 
         residual = None
         for i, layer in enumerate(self.h):
-            # We added padding that we now need to slice
-            layer_past_key_values = (
-                past_key_values[i]
-                if slice_past_index is None
-                else past_key_values[i, :slice_past_index]
-            )
-
             hidden_states, residual = layer(
                 hidden_states,
                 residual,
-                cu_seqlens,
+                start_seq,
+                end_seq,
+                start_seq_q,
+                end_seq_q,
                 max_s,
-                layer_past_key_values,
-                layer_past_present_indices,
-                cu_seqlens_q,
+                torch.select(past_key_values, dim=1, index=i),
+                past_present_indices,
+                prefill,
             )
 
         hidden_states, _ = self.ln_f(hidden_states, residual)
@@ -404,21 +415,27 @@ class FlashSantacoderForCausalLM(nn.Module):
 
     def forward(
         self,
-        input_ids,
-        position_ids,
-        cu_seqlens,
-        cu_seqlens_q,
-        max_s,
-        past_key_values: Optional[torch.Tensor] = None,
-        pre_allocate_past_size: Optional[int] = None,
-        lm_head_indices: Optional[torch.Tensor] = None,
+        input_ids,
+        position_ids,
+        start_seq,
+        end_seq,
+        start_seq_q,
+        end_seq_q,
+        max_s,
+        past_present_indices,
+        past_key_values: Optional[torch.Tensor] = None,
+        pre_allocate_past_size: Optional[int] = None,
+        lm_head_indices: Optional[torch.Tensor] = None,
     ):
         hidden_states, present = self.transformer(
             input_ids,
             position_ids,
-            cu_seqlens,
-            cu_seqlens_q,
+            start_seq,
+            end_seq,
+            start_seq_q,
+            end_seq_q,
             max_s,
+            past_present_indices,
             past_key_values,
             pre_allocate_past_size,
         )
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index ed7a0ec2..1a8bf6fc 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -186,8 +186,7 @@ class FlashCausalLMBatch(Batch):
             prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
             prefill_out_cumulative_length += 1
 
-            request_past_present_indices = np.zeros(input_length + max_new_tokens - 1)
-            request_past_present_indices[:input_length] = 1
+            request_past_present_indices = torch.arange(cumulative_max_length, cumulative_max_length + input_length, dtype=torch.int64)
             past_present_indices.append(request_past_present_indices)
 
             # Update
@@ -210,10 +209,20 @@ class FlashCausalLMBatch(Batch):
         if len(pb.requests) > 1:
             input_ids = np.concatenate(all_input_ids, dtype=np.int64)
             position_ids = torch.cat(position_ids)
+
+            past_present_indices = np.concatenate(past_present_indices, dtype=np.int64)
+
+            start_seq_prefill = torch.tensor(start_seq_prefill, device=device, dtype=torch.int32)
+            end_seq_prefill = torch.tensor(end_seq_prefill, device=device, dtype=torch.int32)
         else:
             input_ids = all_input_ids[0]
             position_ids = position_ids[0]
 
+            past_present_indices = past_present_indices[0]
+
+            start_seq_prefill = start_seq
+            end_seq_prefill = end_seq
+
         # Create tensors on device
         input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
         all_input_ids_tensor = torch.tensor(
@@ -222,19 +231,7 @@ class FlashCausalLMBatch(Batch):
         position_ids = torch.tensor(position_ids, dtype=torch.int32, device=device)
         start_seq = torch.tensor(start_seq, device=device, dtype=torch.int32)
         end_seq = torch.tensor(end_seq, device=device, dtype=torch.int32)
-
-        if len(pb.requests) > 1:
-            past_present_indices = np.concatenate(past_present_indices)
-
-            start_seq_prefill = torch.tensor(start_seq_prefill, device=device, dtype=torch.int32)
-            end_seq_prefill = torch.tensor(end_seq_prefill, device=device, dtype=torch.int32)
-        else:
-            past_present_indices = past_present_indices[0]
-
-            start_seq_prefill = start_seq
-            end_seq_prefill = end_seq
-
-        past_present_indices = torch.tensor(past_present_indices, device=device, dtype=torch.bool)
+        past_present_indices = torch.tensor(past_present_indices, device=device, dtype=torch.int64)
 
         if all_prefill_logprobs:
             prefill_head_indices = None
@@ -298,7 +295,7 @@ class FlashCausalLMBatch(Batch):
         indices = []
 
         # past indices to keep
-        past_indices = torch.zeros(self.past_key_values.shape[1], dtype=torch.bool, device=device)
+        past_indices = torch.zeros(self.past_key_values.shape[0], dtype=torch.bool, device=device)
 
         # Create on CPU to only move to GPU once instead of at every copy
         start_seq = torch.empty(len(request_ids), dtype=torch.int32)
@@ -352,7 +349,7 @@ class FlashCausalLMBatch(Batch):
         position_ids = self.position_ids[indices]
         all_input_ids_tensor = self.all_input_ids_tensor[indices]
         next_token_chooser = self.next_token_chooser.filter(indices)
-        past_key_values = self.past_key_values[:, past_indices]
+        past_key_values = self.past_key_values[past_indices]
 
         # Move to GPU now that we have the whole tensor
         start_seq = start_seq.to(device)
@@ -409,11 +406,7 @@ class FlashCausalLMBatch(Batch):
         )
         end_seq_q = start_seq_q + 1
         max_seqlen = 0
-        past_key_values = batches[0].past_key_values.new_empty((
-            batches[0].past_key_values.shape[0],
-            total_tokens,
-            *batches[0].past_key_values.shape[2:]
-        ))
+        past_key_values = []
 
         all_input_ids = []
 
@@ -449,11 +442,6 @@ class FlashCausalLMBatch(Batch):
             start_seq[start_index:end_index] = batch.start_seq + max_tokens
             end_seq[start_index:end_index] = batch.end_seq + max_tokens
 
-            past_key_values[
-                :,
-                max_tokens: max_tokens + batch.max_tokens
-            ] = batch.past_key_values
-
             max_seqlen = max(max_seqlen, batch.max_seqlen)
 
             all_input_ids.extend(batch.all_input_ids)
@@ -464,6 +452,7 @@ class FlashCausalLMBatch(Batch):
 
             next_token_chooser_parameters.extend([r.parameters for r in batch.requests])
             stopping_criterias.extend(batch.stopping_criterias)
+            past_key_values.append(batch.past_key_values)
 
             # Update
             cumulative_batch_size += len(batch)
@@ -480,6 +469,7 @@ class FlashCausalLMBatch(Batch):
             ),
         )
 
+        past_key_values = torch.cat(past_key_values, dim=0)
        past_present_indices = end_seq - 1
 
         all_input_ids_tensor = torch.zeros(
@@ -726,8 +716,8 @@ class FlashCausalLM(Model):
         # Set values in batch
         batch.input_ids = next_input_ids
         batch.position_ids = next_position_ids + 1
-        batch.past_present_indices = torch.clone(batch.end_seq)
-        batch.end_seq += 1
+        batch.past_present_indices = batch.end_seq
+        batch.end_seq = batch.end_seq + 1
 
         if prefill and prefill_logprobs:
             # Get prefill logprobs
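The central change in this patch is the layout of the pre-allocated KV cache: the token dimension moves to dim 0 and the layer dimension to dim 1, so each layer works on torch.select(past_key_values, dim=1, index=i) and writes through that view into the shared tensor. A minimal sketch of that layout; the shapes and names below are illustrative, not the repository's exact configuration:

```python
import torch

# Token-major KV cache as laid out in this patch (sizes are made up):
# [token_slots, num_layers, 2 (key/value), num_kv_heads, head_size]
num_layers, num_kv_heads, head_size = 4, 1, 64
pre_allocate_past_size = 128  # slots reserved for prompt + future generated tokens

past_key_values = torch.zeros(
    (pre_allocate_past_size, num_layers, 2, num_kv_heads, head_size)
)

# Per-layer view without copying: dim=1 is the layer dimension.
layer_id = 2
layer_past = torch.select(past_key_values, dim=1, index=layer_id)

# Writing the key/value pairs of three tokens into their reserved slots.
past_present_indices = torch.tensor([5, 6, 7], dtype=torch.int64)
kv = torch.randn(3, 2, num_kv_heads, head_size)
layer_past[past_present_indices] = kv  # lands in the shared cache (view semantics)
```

Because torch.select returns a view, the in-place indexed assignment inside each attention module updates the single batch-wide cache, which is why the models can hand torch.select(past_key_values, dim=1, index=i) down to every layer without copies.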
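past_present_indices also changes meaning: instead of a boolean mask over the cache it now holds int64 slot indices. At prefill each request gets a torch.arange over the start of its reserved block; at each decode step the write slot is simply the current end_seq, which is then advanced by one. A rough sketch of that bookkeeping, assuming (as the deleted np.zeros sizing suggests) a reservation of input_length + max_new_tokens - 1 slots per request:

```python
import torch

# Hypothetical batch of two requests; lengths are illustrative.
input_lengths = [3, 5]
max_new_tokens = [4, 2]

past_present_indices = []
end_seq = []
cumulative_max_length = 0
for input_length, new_tokens in zip(input_lengths, max_new_tokens):
    # Prefill writes fill the first `input_length` slots of the request's block.
    request_indices = torch.arange(
        cumulative_max_length,
        cumulative_max_length + input_length,
        dtype=torch.int64,
    )
    past_present_indices.append(request_indices)
    end_seq.append(cumulative_max_length + input_length)
    # Reserve the rest of the block for tokens generated later (assumed sizing).
    cumulative_max_length += input_length + new_tokens - 1

past_present_indices = torch.cat(past_present_indices)  # prefill write slots
end_seq = torch.tensor(end_seq, dtype=torch.int32)

# One decode step: each request writes a single new entry at its current end,
# then the end pointer advances (mirrors the update in FlashCausalLM.generate_token).
past_present_indices = end_seq
end_seq = end_seq + 1
```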
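With tokens on dim 0, the batch surgery in flash_causal_lm.py reduces to plain tensor operations: filter keeps the surviving requests' slots with a boolean mask over dim 0, and concatenate stacks the per-batch caches with torch.cat instead of pre-allocating a padded tensor and copying ranges into it. A small illustration with invented sizes:

```python
import torch

# Two per-batch caches: [tokens, layers, 2, kv_heads, head_size] (sizes invented).
cache_a = torch.randn(10, 4, 2, 1, 64)
cache_b = torch.randn(6, 4, 2, 1, 64)

# concatenate(): token-major layout lets the caches be stacked along dim 0.
merged = torch.cat([cache_a, cache_b], dim=0)

# filter(): keep only the token slots belonging to the surviving requests.
past_indices = torch.zeros(merged.shape[0], dtype=torch.bool)
past_indices[0:10] = True          # e.g. keep the slots of the first request block
filtered = merged[past_indices]    # shape [10, 4, 2, 1, 64]
```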