From a4df5bc64a87958608b22f7956741dc29ba9df23 Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Thu, 23 Mar 2023 14:01:35 +0100
Subject: [PATCH] faster

---
 .../models/flash_neox.py | 56 +++++++++++++------
 1 file changed, 39 insertions(+), 17 deletions(-)

diff --git a/server/text_generation_server/models/flash_neox.py b/server/text_generation_server/models/flash_neox.py
index 206e39e7..97b9f0b5 100644
--- a/server/text_generation_server/models/flash_neox.py
+++ b/server/text_generation_server/models/flash_neox.py
@@ -47,7 +47,10 @@ class FlashNeoXBatch(Batch):
     past_key_values: Optional[torch.Tensor]
 
     # All tokens
-    all_input_ids: List[torch.Tensor]
+    all_input_ids: List[List[int]]
+
+    # Lengths of all generations present in the batch
+    input_lengths: List[int]
 
     # Generation helpers
     next_token_choosers: List[NextTokenChooser]
@@ -70,6 +73,9 @@ class FlashNeoXBatch(Batch):
         cu_seqlens = [0]
         max_seqlen = 0
 
+        input_lengths = []
+        all_input_ids = []
+
         next_token_choosers = []
         stopping_criterias = []
 
@@ -81,9 +87,11 @@ class FlashNeoXBatch(Batch):
                 .squeeze(0)
             )
             input_ids.append(tokenized_input)
+            all_input_ids.append(tokenized_input.tolist())
             position_ids.append(
                 torch.arange(0, len(tokenized_input), dtype=torch.int32, device=device)
             )
+            input_lengths.append(len(tokenized_input))
             cu_seqlens.append(len(tokenized_input))
             max_seqlen = max(max_seqlen, len(tokenized_input))
 
@@ -92,7 +100,6 @@ class FlashNeoXBatch(Batch):
                 StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
             )
 
-        all_input_ids = input_ids
         input_ids = torch.concat(input_ids).unsqueeze(1)
         position_ids = torch.concat(position_ids)
         cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32, device=device)
@@ -105,6 +112,7 @@ class FlashNeoXBatch(Batch):
             cu_seqlens=cu_seqlens,
             max_seqlen=max_seqlen,
             past_key_values=None,
+            input_lengths=input_lengths,
             all_input_ids=all_input_ids,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
@@ -191,6 +199,8 @@ class FlashNeoX(Model):
             batch.past_key_values,
         )
 
+        device = out.device
+
         # List of indices to cache
         next_batch_keep_indices = []
 
@@ -200,14 +210,19 @@ class FlashNeoX(Model):
         next_batch_cu_seqlens = [0]
         next_batch_max_seqlen = 0
         next_batch_past_key_values = []
+        next_batch_input_lengths = []
         next_batch_all_input_ids = []
 
+        # Cumulative length
+        cumulative_length = 0
+
         # Results
         generations: List[Generation] = []
 
         # Zipped iterator
         iterator = zip(
             batch.requests,
+            batch.input_lengths,
             batch.next_token_choosers,
             batch.stopping_criterias,
             batch.all_input_ids,
@@ -216,14 +231,14 @@ class FlashNeoX(Model):
         # For each member of the batch
         for i, (
             request,
+            input_length,
             next_token_chooser,
             stopping_criteria,
             all_input_ids,
         ) in enumerate(iterator):
             # Indexing metadata
-            start_index = batch.cu_seqlens[i]
-            end_index = batch.cu_seqlens[i + 1]
-            seq_length = (end_index - start_index).item()
+            start_index = cumulative_length
+            end_index = cumulative_length + input_length
 
             if batch.past_key_values is None:
                 # Prefill mode
@@ -236,23 +251,28 @@ class FlashNeoX(Model):
 
             # Select next token
             next_token_id, logprobs = next_token_chooser(
-                all_input_ids.view(1, -1), logits
+                all_input_ids, logits
             )
+            next_token_id = next_token_id.to("cpu")
+            logprobs = logprobs.to("cpu")
+
+            next_token_id_squeezed = next_token_id.squeeze()
+            next_token_id_item = next_token_id_squeezed.item()
 
             # Append next token to all tokens
-            all_input_ids = torch.cat([all_input_ids, next_token_id.squeeze(1)])
-            new_input_length = seq_length + 1
+            all_input_ids.append(next_token_id_item)
+            # all_input_ids = torch.cat([all_input_ids, next_token_id.squeeze(1)])
+            new_input_length = input_length + 1
 
             # Generated token
             next_token_logprob = logprobs[-1, next_token_id]
-            next_token_id_squeezed = next_token_id.squeeze()
             next_token_text = self.decode_token(
-                next_token_id_squeezed,
+                next_token_id_item,
             )
 
             # Evaluate stopping criteria
             stop, reason = stopping_criteria(
-                next_token_id_squeezed,
+                next_token_id_item,
                 next_token_text,
             )
 
@@ -279,10 +299,11 @@ class FlashNeoX(Model):
                 generated_text = None
                 next_batch_keep_indices.append(i)
                 next_batch_input_ids.append(next_token_id)
-                next_batch_position_ids.append(seq_length)
+                next_batch_position_ids.append(input_length)
                 next_batch_cu_seqlens.append(
                     next_batch_cu_seqlens[i] + new_input_length
                 )
+                next_batch_input_lengths.append(new_input_length)
                 next_batch_all_input_ids.append(all_input_ids)
                 next_batch_max_seqlen = max(next_batch_max_seqlen, new_input_length)
 
@@ -290,7 +311,7 @@ class FlashNeoX(Model):
             # Prefill
             if stopping_criteria.current_tokens == 1:
                 # Remove generated token to only have prefill and add nan for first prompt token
                 prefill_logprobs = [float("nan")] + logprobs.gather(
-                    1, all_input_ids[1:].unsqueeze(1)
+                    1, torch.tensor(all_input_ids[1:]).unsqueeze(1)
                 ).squeeze(1)[:-1].tolist()
                 prefill_token_ids = all_input_ids[:-1]
                 prefill_texts = self.tokenizer.batch_decode(
@@ -307,14 +328,15 @@ class FlashNeoX(Model):
             generation = Generation(
                 request.id,
                 prefill_tokens,
-                next_token_id_squeezed,
+                next_token_id_item,
                 next_token_logprob,
                 next_token_text,
-                next_token_id_squeezed.item() in self.all_special_ids,
+                next_token_id_item in self.all_special_ids,
                 generated_text,
             )
 
             generations.append(generation)
+            cumulative_length += input_length
 
         # We finished all generations in the batch; there is no next batch
        if not next_batch_keep_indices:
@@ -337,7 +359,6 @@ class FlashNeoX(Model):
             next_batch_stopping_criterias = batch.stopping_criterias
 
         # Create final next batch tensors
-        device = out.device
         next_batch_position_ids = torch.tensor(
             next_batch_position_ids, dtype=torch.int32, device=device
         )
@@ -348,7 +369,7 @@ class FlashNeoX(Model):
             next_batch_input_ids = torch.concat(next_batch_input_ids, dim=0)
             next_batch_past_key_values = torch.concat(next_batch_past_key_values)
         else:
-            next_batch_input_ids = next_batch_input_ids[0]
+            next_batch_input_ids = next_batch_input_ids[0].to(device)
             next_batch_past_key_values = next_batch_past_key_values[0]
 
         next_batch = FlashNeoXBatch(
@@ -359,6 +380,7 @@ class FlashNeoX(Model):
             cu_seqlens=next_batch_cu_seqlens,
             max_seqlen=next_batch_max_seqlen,
             past_key_values=next_batch_past_key_values,
+            input_lengths=next_batch_input_lengths,
             all_input_ids=next_batch_all_input_ids,
             next_token_choosers=next_batch_next_token_choosers,
             stopping_criterias=next_batch_stopping_criterias,
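
The gist of the change: per-request token histories move from GPU tensors to plain Python lists (all_input_ids: List[List[int]], with an explicit input_lengths list), the sampled token is copied to CPU once per step, and request offsets are tracked with a running cumulative_length instead of indexing batch.cu_seqlens. A minimal sketch of the before/after token-append step, assuming hypothetical helper names (not the server's API):

    import torch

    # Before: every generated token pays for a torch.cat that allocates a new GPU tensor.
    def append_token_tensor(all_input_ids: torch.Tensor, next_token_id: torch.Tensor) -> torch.Tensor:
        return torch.cat([all_input_ids, next_token_id.squeeze(1)])

    # After (as in the patch): one device-to-host copy per step, then a cheap list append.
    def append_token_list(all_input_ids: list, next_token_id: torch.Tensor) -> int:
        next_token_id_item = next_token_id.to("cpu").squeeze().item()
        all_input_ids.append(next_token_id_item)
        return next_token_id_item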