From 866af9b9fdc8c590140247adb38239139524b230 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Wed, 29 Nov 2023 14:36:17 +0000
Subject: [PATCH] Speedup 2x.

- Still wrong when batched
- Incorrect returned payload. (No multiple ids/logprobs)
---
 .../custom_modeling/flash_llama_modeling.py   | 16 ++++-
 .../models/flash_causal_lm.py                 | 69 +++++++++++++------
 server/text_generation_server/utils/tokens.py | 14 ++--
 3 files changed, 69 insertions(+), 30 deletions(-)

diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index a9327624..06dd3f5c 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -301,7 +301,6 @@ class FlashLlamaAttention(torch.nn.Module):
             )
         # Decode
         else:
-            import ipdb;ipdb.set_trace()
             paged_attention.attention(
                 attn_output,
                 query,
@@ -454,9 +453,20 @@ class FlashLlamaModel(torch.nn.Module):
         speculative_ids: Optional[torch.Tensor]
     ) -> torch.Tensor:
         if speculative_ids is not None:
-            print(speculative_ids.shape, input_ids.shape)
+            speculative_length = speculative_ids.shape[1]
+            new_length = speculative_length + 1
             new_input_ids = torch.cat([input_ids.unsqueeze(-1), speculative_ids], dim=1).squeeze(0)
-            new_position_ids = (position_ids.view((1, -1)).expand(speculative_ids.shape[1] + 1, 1) + torch.arange(speculative_ids.shape[1] + 1).unsqueeze(1).to(device="cuda:0")).squeeze(0).squeeze(-1)
+            new_position_ids = (position_ids.view((1, -1)).expand(new_length, 1) + torch.arange(new_length).unsqueeze(1).to(device=position_ids.device)).squeeze(0).squeeze(-1)
+
+            # Add an extra block just in case
+            block_tables = torch.cat([block_tables, block_tables[:, -1:] + 1], dim=1)
+            # Copy the block tables for all members
+            block_tables = block_tables.expand(new_length, -1).contiguous()
+            slots = slots.expand(new_length) + torch.arange(new_length, dtype=slots.dtype).to(device=slots.device)
+            input_lengths = input_lengths.expand(new_length) + torch.arange(new_length, dtype=input_lengths.dtype).to(device=input_lengths.device)
+            max_s = max_s + speculative_length
+
+            input_ids = new_input_ids
             position_ids = new_position_ids
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index f2486ddc..3b3aa400 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -793,7 +793,7 @@ class FlashCausalLM(Model):
         # if next_token_logits.shape[0] == 3:
         #     import ipdb;ipdb.set_trace()
-        next_input_ids, next_token_logprobs, logprobs, speculative_ids = batch.next_token_chooser(
+        next_input_ids, next_token_logprobs, logprobs, accepted_ids, speculative_ids = batch.next_token_chooser(
             batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits, batch.speculative_ids, speculative_logits
         )
@@ -835,6 +835,7 @@ class FlashCausalLM(Model):
         iterator = zip(
             batch.input_lengths,
             batch.all_input_ids,
+            accepted_ids
         )

         # We do two for loops as the first one can run completely asynchronously from the GPU while for the second
         # It is faster if we delay this sync for the maximum amount of time

         # For each member of the batch
-        step = 1 + speculative_length
+        index = 0
         for i, (
             input_length,
             all_input_ids,
+            n_accepted_ids
         ) in enumerate(iterator):
             # Indexing metadata
             start_index = cumulative_length
@@ -859,8 +861,6 @@ class FlashCausalLM(Model):

                 # Initialize position_ids
                 # In decode, we do not need this as we can just increment position ids
-                # for j in range(1 + speculative_length):
-                #     next_position_ids[i * step + j] = batch.position_ids[end_index - 1] + j
                 next_position_ids[i] = batch.position_ids[end_index - 1]

                 # Used to gather prefill logprobs
@@ -876,17 +876,29 @@ class FlashCausalLM(Model):
                         start_index + 1 : start_index + out_length
                     ]

-            batch.all_input_ids_tensor[i, input_length] = next_input_ids[i]
+            for j in range(n_accepted_ids):
+                batch.all_input_ids_tensor[i, input_length + j] = next_input_ids[index]
+                index += 1

             cumulative_length += input_length
+
+        # if accepted_ids[0] > 1:
+        #     import ipdb;ipdb.set_trace()
+
+        if len(accepted_ids) > 1:
+            raise Exception("Implement the batched behavior")
+
         # Set values in batch
         # batch.input_ids = torch.cat([next_input_ids.unsqueeze(-1), speculative_ids], dim=1).view(-1)
-        batch.input_ids = next_input_ids
-        batch.speculative_ids = speculative_ids
-        batch.position_ids = next_position_ids + 1
-        batch.input_lengths_tensor += 1
-        batch.slot_indices += 1
+
+        for n_accepted_ids in accepted_ids:
+            # TODO Make this batched
+            batch.input_ids = next_input_ids[-1:]
+            batch.speculative_ids = speculative_ids
+            batch.position_ids = next_position_ids + n_accepted_ids
+            batch.input_lengths_tensor += n_accepted_ids
+            batch.slot_indices += n_accepted_ids

         if prefill and prefill_logprobs:
             # Get prefill logprobs
@@ -899,7 +911,7 @@ class FlashCausalLM(Model):

         # GPU <-> CPU sync
         next_token_logprobs = next_token_logprobs.tolist()
-        next_token_ids = batch.input_ids.tolist()
+        next_token_ids = next_input_ids.tolist()

         # Zipped iterator
         iterator = zip(
@@ -912,13 +924,15 @@ class FlashCausalLM(Model):
             batch.next_token_chooser.do_sample,
             batch.next_token_chooser.seeds,
             batch.top_n_tokens,
-            next_token_ids,
-            next_token_logprobs,
+            # next_token_ids,
+            # next_token_logprobs,
+            accepted_ids,
             batch_top_token_ids,
             batch_top_token_logprobs,
         )

         # For each member of the batch
+        index = 0
         for i, (
             request,
             input_length,
@@ -929,13 +943,16 @@ class FlashCausalLM(Model):
             do_sample,
             seed,
             top_n_tokens,
-            next_token_id,
-            next_token_logprob,
+            # next_token_id,
+            # next_token_logprob,
+            n_accepted_ids,
             top_token_ids,
             top_token_logprobs,
         ) in enumerate(iterator):
             # Append next token to all tokens
-            all_input_ids.append(next_token_id)
+            _next_token_ids = next_token_ids[index: index+n_accepted_ids]
+            _next_token_logprobs = next_token_logprobs[index: index+n_accepted_ids]
+            all_input_ids.extend(_next_token_ids)

             # Generated token
             next_token_text, prefix_offset, read_offset = self.decode_token(
@@ -945,13 +962,18 @@ class FlashCausalLM(Model):
             )

             # Evaluate stopping criteria
-            stop, reason = stopping_criteria(
-                next_token_id,
-                next_token_text,
-            )
-            if not stop:
-                stopped = False
+            for next_token_id in _next_token_ids:
+                stop, reason = stopping_criteria(
+                    next_token_id,
+                    next_token_text,
+                )
+
+                if stop:
+                    stopped = True
+                    break
+            if not stop:
+                stopped = False

             # Shard generations
             # All generations will be appended in the rust sharded client
@@ -1015,6 +1037,9 @@ class FlashCausalLM(Model):
                 else:
                     top_tokens = None

+                next_token_ids = _next_token_ids[0]
+                next_token_logprob = _next_token_logprobs[0]
+
                 generation = Generation(
                     request.id,
                     prefill_tokens,
diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py
index ab1ea83c..c5e07cca 100644
--- a/server/text_generation_server/utils/tokens.py
+++ b/server/text_generation_server/utils/tokens.py
@@ -225,19 +225,21 @@ class HeterogeneousNextTokenChooser:
             scores = warper(input_ids, scores)

+        accepted_ids = []
         next_ids = self.choice(scores)
         if speculated_ids is not None:
-            validate_speculative = next_ids[1:] == speculated_ids[0]
+            validate_speculative = next_ids[:-1] == speculated_ids[0]
             index = 1
             for valid in validate_speculative.tolist():
                 if valid:
                     index += 1
-            print(f"Validated {index - 1}")
+            # print(f"Validated {index - 1}")
             next_ids = next_ids[:index]
             scores = scores[:index]
             speculative_scores = speculative_scores[index - 1:index]
-            if index > 1:
-                import ipdb;ipdb.set_trace()
+            accepted_ids.append(index)
+        else:
+            accepted_ids.append(1)

         logprobs = torch.log_softmax(scores, -1)
@@ -255,10 +257,12 @@ class HeterogeneousNextTokenChooser:
             # for warper in self.warpers:
             #     speculative_scores = warper(input_ids, speculative_scores)
             speculative_ids = Greedy()(speculative_scores)
+            # # Ignore first head, it seems to be a regular head.
+            # speculative_ids = speculative_ids[:, 1:]
         else:
             speculative_ids = None

-        return next_ids, next_logprobs, logprobs, speculative_ids
+        return next_ids, next_logprobs, logprobs, accepted_ids, speculative_ids

     def filter(self, indices):
         if self.watermark_processor is not None:
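
Reviewer note (not part of the patch): the tokens.py hunk above compares the previously
speculated ids against what the model actually samples and records, per sequence, how many
ids are kept (accepted_ids = accepted draft tokens plus the one freshly sampled token).
Below is a minimal, self-contained sketch of that acceptance rule for a single sequence.
The helper name count_accepted, the explicit break on the first mismatch, and the example
tensors are illustrative assumptions, not code from this patch; the counting loop in the
hunk does not stop at the first mismatch, so this sketch shows the usual prefix-acceptance
rule rather than the exact code path.

    import torch

    def count_accepted(next_ids: torch.Tensor, speculated_ids: torch.Tensor) -> int:
        # next_ids: the token sampled at each of the n_spec + 1 scored positions
        # speculated_ids: the n_spec draft tokens fed in after the last real token
        # A draft counts as accepted only if every draft before it matched and the model
        # sampled that same token at its position; the final sampled token is always kept.
        accepted = 1
        for match in (next_ids[:-1] == speculated_ids).tolist():
            if not match:
                break
            accepted += 1
        return accepted

    # The model agrees with the first two draft tokens, then diverges:
    # 2 accepted drafts + 1 sampled token = 3 ids kept for this sequence.
    next_ids = torch.tensor([5, 7, 9, 4])
    speculated_ids = torch.tensor([5, 7, 2])
    assert count_accepted(next_ids, speculated_ids) == 3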