From 8897b896062eb15afcfa449c74f106c22cc4f5a2 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Tue, 28 Nov 2023 22:23:03 +0000
Subject: [PATCH] Speculative medusa (illegal address Paged).

---
 .../custom_modeling/flash_llama_modeling.py    | 11 +++++++
 .../models/flash_causal_lm.py                  | 31 +++++++++++++------
 server/text_generation_server/utils/tokens.py  | 16 +++++++++-
 3 files changed, 48 insertions(+), 10 deletions(-)

diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index 4aeb447d..a9327624 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -301,6 +301,7 @@ class FlashLlamaAttention(torch.nn.Module):
             )
         # Decode
         else:
+            import ipdb;ipdb.set_trace()
             paged_attention.attention(
                 attn_output,
                 query,
@@ -450,7 +451,15 @@ class FlashLlamaModel(torch.nn.Module):
         slots: torch.Tensor,
         input_lengths: torch.Tensor,
         max_s: int,
+        speculative_ids: Optional[torch.Tensor]
     ) -> torch.Tensor:
+        if speculative_ids is not None:
+            print(speculative_ids.shape, input_ids.shape)
+            new_input_ids = torch.cat([input_ids.unsqueeze(-1), speculative_ids], dim=1).squeeze(0)
+            new_position_ids = (position_ids.view((1, -1)).expand(speculative_ids.shape[1] + 1, 1) + torch.arange(speculative_ids.shape[1] + 1).unsqueeze(1).to(device="cuda:0")).squeeze(0).squeeze(-1)
+            input_ids = new_input_ids
+            position_ids = new_position_ids
+
         hidden_states = self.embed_tokens(input_ids)
 
         # Get rotary cos and sin for this forward
@@ -501,6 +510,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         input_lengths: torch.Tensor,
         max_s: int,
         lm_head_indices: Optional[torch.Tensor] = None,
+        speculative_ids: Optional[torch.Tensor] = None
     ) -> torch.Tensor:
         hidden_states = self.model(
             input_ids,
@@ -511,6 +521,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
             slots,
             input_lengths,
             max_s,
+            speculative_ids,
         )
         if lm_head_indices is not None:
             hidden_states = hidden_states[lm_head_indices]
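The FlashLlamaModel.forward change above widens a decode step from one token per sequence to 1 + S tokens whenever speculative ids are passed in; the patch's version squeezes the batch dimension and hard-codes cuda:0, so it effectively assumes batch size 1. Below is a minimal, hypothetical sketch of the same expansion for a batched case; the function name and shapes are assumptions for illustration, not code from the repository.

import torch

def expand_for_speculation(input_ids, position_ids, speculative_ids):
    # Illustrative sketch, not part of the patch.
    # input_ids: (batch,) last accepted token of each sequence
    # position_ids: (batch,) position of that token
    # speculative_ids: (batch, S) candidate tokens to verify in this forward pass
    s = speculative_ids.shape[1]
    # (batch, 1 + S): the real token followed by its speculative continuation
    new_input_ids = torch.cat([input_ids.unsqueeze(-1), speculative_ids], dim=1)
    # (batch, 1 + S): positions grow by one for each appended candidate
    offsets = torch.arange(s + 1, device=position_ids.device)
    new_position_ids = position_ids.unsqueeze(-1) + offsets
    # Flatten so the attention kernels see one contiguous token stream
    return new_input_ids.view(-1), new_position_ids.view(-1)

Flattening this way keeps a regular layout: request i occupies rows i * (1 + S) through i * (1 + S) + S of the expanded batch.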
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index b1474e0d..f2486ddc 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -41,6 +41,7 @@ class FlashCausalLMBatch(Batch):
     # Decoder values
     input_ids: torch.Tensor
     position_ids: torch.Tensor
+    speculative_ids: torch.Tensor
 
     # Flash Attention values
 
@@ -121,6 +122,7 @@ class FlashCausalLMBatch(Batch):
         )["input_ids"]
 
         position_ids = []
+        speculative_ids = []
         cu_seqlen_prefill = [0]
         cu_seqlen_speculative = [0]
         needed_blocks_slots = []
@@ -162,10 +164,11 @@ class FlashCausalLMBatch(Batch):
 
             tokenized_input = tokenized_input[-r.truncate :]
 
-            # TODO remove this
-            # Scaffolding to speculate some ids
-            speculate_ids = [1, 2]
-            tokenized_input.extend([1, 2])
+            # # TODO remove this
+            # # Scaffolding to speculate some ids
+            # speculate_ids = [1, 2]
+            # tokenized_input.extend([1, 2])
+            speculate_ids = []
 
             input_length = len(tokenized_input)
 
@@ -324,6 +327,7 @@ class FlashCausalLMBatch(Batch):
             top_n_tokens_tensor=top_n_tokens_tensor,
             blocks=blocks,
            max_blocks=max_blocks,
+            speculative_ids=None,
         )
 
     @tracer.start_as_current_span("filter")
@@ -739,6 +743,7 @@ class FlashCausalLM(Model):
            input_lengths=batch.input_lengths_tensor,
            max_s=batch.max_seqlen,
            lm_head_indices=batch.prefill_head_indices,
+            speculative_ids =batch.speculative_ids
        )
 
    @tracer.start_as_current_span("generate_token")
@@ -786,16 +791,17 @@ class FlashCausalLM(Model):
             next_token_logits = out
 
-
-        import ipdb;ipdb.set_trace()
+        # if next_token_logits.shape[0] == 3:
+        #     import ipdb;ipdb.set_trace()
         next_input_ids, next_token_logprobs, logprobs, speculative_ids = batch.next_token_chooser(
-            batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits, speculative_logits
+            batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits, batch.speculative_ids, speculative_logits
         )
 
         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
             batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs
         )
 
+        speculative_length = 0 if speculative_ids is None else speculative_ids.shape[1]
         if prefill:
             if len(batch) > 1 and prefill_logprobs:
                 # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
@@ -803,13 +809,13 @@ class FlashCausalLM(Model):
                 prefill_tokens_indices = batch.input_ids.new_zeros(len(out))
 
             if speculative_ids is not None:
-                # TODO
-                # length = len(batch) * speculative_ids.shape[1]
+                # length = len(batch) * (1 + speculative_length)
                 length = len(batch)
             else:
                 length = len(batch)
             # import ipdb;ipdb.set_trace()
             next_position_ids = batch.position_ids.new_empty(length)
+            # Keep only 1 slot index, TODO make sure we recover the speculated ids slots later
             batch.slot_indices = batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1]
             # We do not need cu_seqlen_prefill anymore
             batch.cu_seqlen_prefill = None
@@ -836,6 +842,7 @@ class FlashCausalLM(Model):
         # It is faster if we delay this sync for the maximum amount of time
 
         # For each member of the batch
+        step = 1 + speculative_length
         for i, (
             input_length,
             all_input_ids,
@@ -852,6 +859,8 @@ class FlashCausalLM(Model):
 
             # Initialize position_ids
             # In decode, we do not need this as we can just increment position ids
+            # for j in range(1 + speculative_length):
+            #     next_position_ids[i * step + j] = batch.position_ids[end_index - 1] + j
             next_position_ids[i] = batch.position_ids[end_index - 1]
 
             # Used to gather prefill logprobs
@@ -872,7 +881,9 @@ class FlashCausalLM(Model):
             cumulative_length += input_length
 
         # Set values in batch
+        # batch.input_ids = torch.cat([next_input_ids.unsqueeze(-1), speculative_ids], dim=1).view(-1)
         batch.input_ids = next_input_ids
+        batch.speculative_ids = speculative_ids
         batch.position_ids = next_position_ids + 1
         batch.input_lengths_tensor += 1
         batch.slot_indices += 1
@@ -1031,6 +1042,8 @@ class FlashCausalLM(Model):
         batch.prefill_cu_outlens = None
         batch.prefill_head_indices = None
         batch.prefill_next_token_indices = None
+        if prefill:
+            batch.max_seqlen += speculative_length
         batch.max_seqlen = batch.max_seqlen + 1
 
         return generations, batch
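In generate_token the patch records speculative_length, passes batch.speculative_ids to the next-token chooser, and for now keeps a single slot and position index per request; the i * step + j bookkeeping needed to track every speculated position is left commented out. As a hedged sketch of what that commented-out layout implies, assuming each request contributes exactly 1 + speculative_length rows to the model output (the helper name is hypothetical, not from the repository):

import torch

def per_request_rows(batch_size: int, speculative_length: int) -> torch.Tensor:
    # Illustrative sketch: with a fixed step of 1 + speculative_length rows per
    # request, row j of request i sits at flat index i * step + j, which is what
    # the commented-out `next_position_ids[i * step + j]` loop above would touch
    # one element at a time.
    step = 1 + speculative_length
    return torch.arange(batch_size).unsqueeze(1) * step + torch.arange(step)

For example, per_request_rows(2, 3) returns [[0, 1, 2, 3], [4, 5, 6, 7]], i.e. the rows each of the two requests would own in the stacked logits.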
diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py
index eee4f660..ab1ea83c 100644
--- a/server/text_generation_server/utils/tokens.py
+++ b/server/text_generation_server/utils/tokens.py
@@ -215,7 +215,7 @@ class HeterogeneousNextTokenChooser:
         self.dtype = dtype
         self.device = device
 
-    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor, speculative_scores: Optional[torch.Tensor] = None):
+    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor, speculated_ids: Optional[torch.Tensor] = None, speculative_scores: Optional[torch.Tensor] = None):
         if self.watermark_processor is not None:
             scores = self.watermark_processor(input_ids, scores)
         if self.repetition_processor is not None:
             scores = self.repetition_processor(input_ids, scores)
@@ -226,6 +226,20 @@ class HeterogeneousNextTokenChooser:
 
         next_ids = self.choice(scores)
+        if speculated_ids is not None:
+            validate_speculative = next_ids[1:] == speculated_ids[0]
+            index = 1
+            for valid in validate_speculative.tolist():
+                if valid:
+                    index += 1
+            print(f"Validated {index - 1}")
+            next_ids = next_ids[:index]
+            scores = scores[:index]
+            speculative_scores = speculative_scores[index - 1:index]
+            if index > 1:
+                import ipdb;ipdb.set_trace()
+
+
         logprobs = torch.log_softmax(scores, -1)
         next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)
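The tokens.py hunk is where the speculated tokens are verified: the chooser samples next_ids for all 1 + S positions, compares next_ids[1:] against the previously speculated ids, and truncates next_ids, scores, and speculative_scores to the accepted prefix. Below is a minimal sketch of the conventional greedy acceptance rule for a single sequence, with speculated_ids flattened to shape (S,); unlike the loop in the hunk, it stops counting at the first mismatch. The function name is illustrative, not part of the patch.

import torch

def accepted_length(next_ids: torch.Tensor, speculated_ids: torch.Tensor) -> int:
    # Illustrative sketch. next_ids[0] is the token the model sampled at the
    # current position; next_ids[1:] are the tokens sampled at the positions
    # that followed each speculated token. A speculated token only counts if
    # every token before it was also correct, so stop at the first mismatch.
    matches = (next_ids[1:] == speculated_ids).tolist()
    accepted = 0
    for match in matches:
        if not match:
            break
        accepted += 1
    # 1 + accepted tokens can be emitted this step: the guaranteed token plus
    # every verified speculation.
    return 1 + accepted

Slicing next_ids and scores with that count mirrors the next_ids[:index] / scores[:index] truncation performed in the hunk above.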