From c9e7471742a669a374149b040da21859cd7bdf69 Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Thu, 1 Jun 2023 13:32:48 +0200
Subject: [PATCH] working rw 7b

---
 .../custom_modeling/flash_rw_modeling.py      | 48 +++++-------------
 .../models/flash_causal_lm.py                 | 49 +++++++++++++------
 2 files changed, 47 insertions(+), 50 deletions(-)

diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
index f727f8ad..4d1ecd71 100644
--- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
@@ -1,7 +1,6 @@
 import torch
 import torch.distributed
 
-from loguru import logger
 from torch import nn
 from transformers.modeling_utils import PreTrainedModel
 from transformers.configuration_utils import PretrainedConfig
@@ -139,7 +138,6 @@ class FlashRWAttention(torch.nn.Module):
         layer_past,
         layer_past_present_indices,
         prefill,
-        past_stream
     ):
         qkv = self.query_key_value(hidden_states)
 
@@ -159,10 +157,8 @@ class FlashRWAttention(torch.nn.Module):
 
         # Prefill
         if prefill:
-            past_stream.wait_stream(torch.cuda.current_stream())
-            with torch.cuda.stream(past_stream):
-                # Copy to layer past
-                layer_past[layer_past_present_indices] = kv
+            # Copy to layer past
+            layer_past[layer_past_present_indices] = kv
 
             # Expand to query shape
             kv = kv.expand(-1, 2, self.num_heads, self.head_size)
@@ -190,7 +186,6 @@ class FlashRWAttention(torch.nn.Module):
             )
         # Decode
         else:
-            torch.cuda.current_stream().wait_stream(past_stream)
             # Add present to the layer_past tensor at the correct indices
             layer_past[layer_past_present_indices] = kv
             # Expand to query shape
@@ -437,7 +432,6 @@ class FlashRWLayer(nn.Module):
         layer_past,
         layer_past_present_indices,
         prefill,
-        past_stream,
     ):
         if self.parallel_attn:
             ln_hidden_states, residual = self.input_layernorm(hidden_states, residual)
@@ -454,7 +448,6 @@ class FlashRWLayer(nn.Module):
                 layer_past,
                 layer_past_present_indices,
                 prefill,
-                past_stream
             )
 
             mlp_output = self.mlp(ln_hidden_states)
@@ -601,7 +594,6 @@ class FlashRWModel(FlashRWPreTrainedModel):
         )
 
         self.head_size = self.h[0].self_attention.head_size
-        self.past_stream = torch.cuda.Stream()
 
     def forward(
         self,
@@ -612,6 +604,7 @@ class FlashRWModel(FlashRWPreTrainedModel):
         start_seq_q,
         end_seq_q,
         max_s,
+        past_present_indices,
         past_key_values=None,
         pre_allocate_past_size: Optional[int] = None,
     ):
@@ -623,33 +616,17 @@ class FlashRWModel(FlashRWPreTrainedModel):
 
             prefill = True
 
-            with torch.cuda.stream(self.past_stream):
-                # Create past tensor
-                past_key_values = hidden_states.new_zeros(
-                    (
-                        len(self.h),
-                        pre_allocate_past_size,
-                        *self.cache_size,
-                    )
+            # Create past tensor
+            past_key_values = hidden_states.new_zeros(
+                (
+                    len(self.h),
+                    pre_allocate_past_size,
+                    *self.cache_size,
                 )
-                seq_indices = []
-                for s, e in zip(start_seq, end_seq):
-                    seq_indices.append(
-                        torch.arange(
-                            s,
-                            e,
-                            dtype=torch.int64,
-                            device=self.device,
-                        )
-                    )
-                layer_past_present_indices = torch.cat(seq_indices)
-                from loguru import logger
-                logger.error(f"layer past: {layer_past_present_indices}")
+            )
         # Decode
         else:
             prefill = False
-            # Create indices from cumulative sequence lengths
-            layer_past_present_indices = end_seq - 1
 
         # Get rotary cos and sin for this forward
         # Avoid indexing in each layer
@@ -670,9 +647,8 @@ class FlashRWModel(FlashRWPreTrainedModel):
                 end_seq_q,
                 max_s,
                 past_key_values[i],
-                layer_past_present_indices,
+                past_present_indices,
                 prefill,
-                self.past_stream
             )
 
         hidden_states, _ = self.ln_f(hidden_states, residual)
@@ -699,6 +675,7 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
         start_seq_q,
         end_seq_q,
         max_s,
+        past_present_indices,
         past_key_values: Optional[torch.Tensor] = None,
         pre_allocate_past_size: Optional[int] = None,
         lm_head_indices: Optional[torch.Tensor] = None,
@@ -711,6 +688,7 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
             start_seq_q,
             end_seq_q,
             max_s,
+            past_present_indices,
             past_key_values,
             pre_allocate_past_size,
         )
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index d3323cb8..ed7a0ec2 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -34,6 +34,9 @@ class FlashCausalLMBatch(Batch):
     input_ids: torch.Tensor
     position_ids: torch.Tensor
 
+    # Indices to copy the present values to the correct indices in the pre-allocated past key values
+    past_present_indices: torch.Tensor
+
     # tensor of length b holding starting offset of each sequence
     start_seq: torch.Tensor
     # tensor of length b holding ending offset of each sequence
@@ -98,6 +101,7 @@ class FlashCausalLMBatch(Batch):
         )["input_ids"]
 
         position_ids = []
+        past_present_indices = []
         start_seq = []
         end_seq = []
         start_seq_prefill = []
@@ -182,6 +186,10 @@ class FlashCausalLMBatch(Batch):
                 prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
                 prefill_out_cumulative_length += 1
 
+            request_past_present_indices = np.zeros(input_length + max_new_tokens - 1)
+            request_past_present_indices[:input_length] = 1
+            past_present_indices.append(request_past_present_indices)
+
             # Update
             # Remove one as the first token does not have a past
             cumulative_length += input_length
@@ -214,13 +222,20 @@ class FlashCausalLMBatch(Batch):
         position_ids = torch.tensor(position_ids, dtype=torch.int32, device=device)
         start_seq = torch.tensor(start_seq, device=device, dtype=torch.int32)
         end_seq = torch.tensor(end_seq, device=device, dtype=torch.int32)
+
         if len(pb.requests) > 1:
+            past_present_indices = np.concatenate(past_present_indices)
+
             start_seq_prefill = torch.tensor(start_seq_prefill, device=device, dtype=torch.int32)
             end_seq_prefill = torch.tensor(end_seq_prefill, device=device, dtype=torch.int32)
         else:
+            past_present_indices = past_present_indices[0]
+
             start_seq_prefill = start_seq
             end_seq_prefill = end_seq
 
+        past_present_indices = torch.tensor(past_present_indices, device=device, dtype=torch.bool)
+
         if all_prefill_logprobs:
             prefill_head_indices = None
             prefill_next_token_indices = end_seq - 1
@@ -241,6 +256,7 @@ class FlashCausalLMBatch(Batch):
             requests_idx_mapping=requests_idx_mapping,
             input_ids=input_ids,
             position_ids=position_ids,
+            past_present_indices=past_present_indices,
             start_seq=start_seq,
             end_seq=end_seq,
             start_seq_prefill=start_seq_prefill,
@@ -270,7 +286,7 @@ class FlashCausalLMBatch(Batch):
         if len(request_ids) == len(self):
             return self
 
-        single_request = len(request_ids) == 1
+        device = self.input_ids.device
 
         # Cumulative length
         cumulative_max_length = 0
@@ -281,13 +297,15 @@ class FlashCausalLMBatch(Batch):
         # Used to index into tensors
         indices = []
 
+        # past indices to keep
+        past_indices = torch.zeros(self.past_key_values.shape[1], dtype=torch.bool, device=device)
+
         # Create on CPU to only move to GPU once instead of at every copy
         start_seq = torch.empty(len(request_ids), dtype=torch.int32)
         end_seq = torch.empty(len(request_ids), dtype=torch.int32)
         start_seq_q = self.start_seq_q[: len(request_ids)]
         end_seq_q = self.end_seq_q[: len(request_ids)]
         max_seqlen = 0
-        past_key_values = []
 
         requests = []
         all_input_ids = []
@@ -324,11 +342,8 @@ class FlashCausalLMBatch(Batch):
             start_seq[i] = cumulative_max_length
             end_seq[i] = cumulative_max_length + request_input_length
 
-            # Slice from past
-            past_key_values.append(
-                self.past_key_values[:,
-                    self.start_seq[idx]: self.end_seq[idx] + remaining_tokens - 1]
-            )
+            # Set slice
+            past_indices[self.start_seq[idx]: self.end_seq[idx] + remaining_tokens - 1] = True
 
             cumulative_max_length += request_input_length + remaining_tokens - 1
 
@@ -337,16 +352,12 @@ class FlashCausalLMBatch(Batch):
         input_ids = self.input_ids[indices]
         position_ids = self.position_ids[indices]
         all_input_ids_tensor = self.all_input_ids_tensor[indices]
         next_token_chooser = self.next_token_chooser.filter(indices)
-
-        if single_request:
-            past_key_values = past_key_values[0]
-        else:
-            # Cat all past
-            past_key_values = torch.cat(past_key_values, dim=1)
+        past_key_values = self.past_key_values[:, past_indices]
 
         # Move to GPU now that we have the whole tensor
-        start_seq = start_seq.to(self.start_seq.device)
-        end_seq = end_seq.to(self.start_seq.device)
+        start_seq = start_seq.to(device)
+        end_seq = end_seq.to(device)
+        past_present_indices = end_seq - 1
 
         return FlashCausalLMBatch(
@@ -354,6 +365,7 @@ class FlashCausalLMBatch(Batch):
             requests_idx_mapping=requests_idx_mapping,
             input_ids=input_ids,
             position_ids=position_ids,
+            past_present_indices=past_present_indices,
             start_seq=start_seq,
             end_seq=end_seq,
             start_seq_prefill=None,
@@ -468,6 +480,8 @@ class FlashCausalLMBatch(Batch):
                 ),
             )
 
+        past_present_indices = end_seq - 1
+
         all_input_ids_tensor = torch.zeros(
             (total_batch_size, max_length), dtype=torch.int64, device=device
         )
@@ -493,6 +507,7 @@ class FlashCausalLMBatch(Batch):
             requests_idx_mapping=requests_idx_mapping,
             input_ids=input_ids,
             position_ids=position_ids,
+            past_present_indices=past_present_indices,
             start_seq=start_seq,
             end_seq=end_seq,
             start_seq_prefill=None,
@@ -574,6 +589,7 @@ class FlashCausalLM(Model):
         start_seq_q: Optional[torch.Tensor],
         end_seq_q: Optional[torch.Tensor],
         max_s: int,
+        past_present_indices: torch.Tensor,
         past_key_values: Optional = None,
         pre_allocate_past_size: Optional[int] = None,
         lm_head_indices: Optional[torch.Tensor] = None,
@@ -587,6 +603,7 @@ class FlashCausalLM(Model):
             start_seq_q=start_seq_q,
             end_seq_q=end_seq_q,
             max_s=max_s,
+            past_present_indices=past_present_indices,
             past_key_values=past_key_values,
             pre_allocate_past_size=pre_allocate_past_size,
             lm_head_indices=lm_head_indices,
@@ -619,6 +636,7 @@ class FlashCausalLM(Model):
             batch.start_seq_q,
             batch.end_seq_q,
             batch.max_seqlen,
+            batch.past_present_indices,
            batch.past_key_values,
             pre_allocate_past_size,
             batch.prefill_head_indices,
@@ -708,6 +726,7 @@ class FlashCausalLM(Model):
         # Set values in batch
         batch.input_ids = next_input_ids
         batch.position_ids = next_position_ids + 1
+        batch.past_present_indices = torch.clone(batch.end_seq)
         batch.end_seq += 1
 
         if prefill and prefill_logprobs:
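
Reviewer note (sketch, not part of the patch): the `from_pb` change above
reserves `input_length + max_new_tokens - 1` cache slots per request and flags
the first `input_length` of them, so prefill can fill the pre-allocated past
with a single masked copy. A minimal standalone illustration with made-up
lengths (3 prompt tokens / 2 new tokens, then 2 / 3):

    import numpy as np
    import torch

    past_present_indices = []
    for input_length, max_new_tokens in [(3, 2), (2, 3)]:
        # One slot per past and present token; minus one, as the first
        # token does not have a past.
        request_mask = np.zeros(input_length + max_new_tokens - 1)
        request_mask[:input_length] = 1
        past_present_indices.append(request_mask)

    mask = torch.tensor(np.concatenate(past_present_indices), dtype=torch.bool)
    # mask == [True, True, True, False, True, True, False, False]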
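
With `past_present_indices` precomputed on the batch, the attention layers no
longer need the side CUDA stream: prefill and decode both reduce to one indexed
assignment into the pre-allocated cache. A simplified sketch of that mechanism
(shapes and slot counts are illustrative, not the model's real ones):

    import torch

    num_heads, head_size = 2, 4
    pre_allocate_past_size = 8  # sum of the two per-request reservations above

    layer_past = torch.zeros(pre_allocate_past_size, 2, num_heads, head_size)

    # Prefill: a boolean mask selects the slots of every prompt token at once.
    past_present_indices = torch.tensor(
        [1, 1, 1, 0, 1, 1, 0, 0], dtype=torch.bool
    )
    kv = torch.randn(int(past_present_indices.sum()), 2, num_heads, head_size)
    layer_past[past_present_indices] = kv

    # Decode: integer indices point at the next free slot of each sequence,
    # mirroring `batch.past_present_indices = torch.clone(batch.end_seq)`.
    end_seq = torch.tensor([3, 6])
    layer_past[end_seq] = torch.randn(2, 2, num_heads, head_size)
    end_seq += 1  # matches `batch.end_seq += 1` after the forward pass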
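
`filter` previously sliced the past per request and concatenated the pieces;
with a single pre-allocated tensor it can instead flag the rows to keep and
gather them with one boolean-mask index, as the diff does with `past_indices`.
A hypothetical sketch:

    import torch

    num_layers, total_slots, kv_dim = 2, 8, 3
    past_key_values = torch.randn(num_layers, total_slots, kv_dim)

    # Keep only the slots of the surviving request, e.g. slots 0..3.
    past_indices = torch.zeros(total_slots, dtype=torch.bool)
    past_indices[0:4] = True

    past_key_values = past_key_values[:, past_indices]
    assert past_key_values.shape == (num_layers, 4, kv_dim)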