diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 61ccca84..6e072d4c 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -431,19 +431,15 @@ class FlashCausalLM(Model):
         # Shortcut when batch_size == 1
         if len(batch) == 1:
             input_ids = batch.input_ids[0].view(-1)
-            # No need to slice as flash attention will take care of it with cu_seqlens
-            past_key_values = batch.past_key_values
         else:
             # Concatenate tensors
-            input_ids = torch.cat(batch.input_ids).view(-1)
-            past_key_values = (
-                torch.cat(batch.past_key_values, dim=1)
-                if batch.past_key_values is not None
-                else None
-            )
+            if not isinstance(batch.input_ids, torch.Tensor):
+                input_ids = torch.cat(batch.input_ids).view(-1)
+            else:
+                input_ids = batch.input_ids.view(-1)

         # if prefill and bs == 1
-        if past_key_values is None and len(batch) == 1:
+        if batch.past_key_values is None and len(batch) == 1:
             # Ask to pre-allocate kv to its max size
             # == number of tokens + max_new_tokens
             pre_allocate_past_size = (
@@ -453,11 +449,11 @@ class FlashCausalLM(Model):
             pre_allocate_past_size = None

         # Concatenate when prefill, torch.tensor when decode
-        position_ids = (
-            torch.tensor(batch.position_ids, device=self.device)
-            if batch.past_key_values is not None
-            else torch.cat(batch.position_ids)
-        )
+        if batch.past_key_values is None:
+            position_ids = torch.cat(batch.position_ids)
+        else:
+            position_ids = batch.position_ids
+
         cu_seqlens = torch.tensor(
             batch.cu_seqlens, device=self.device, dtype=torch.int32
         )
@@ -467,28 +463,10 @@ class FlashCausalLM(Model):
             position_ids,
             cu_seqlens,
             batch.max_seqlen,
-            past_key_values,
+            batch.past_key_values,
             pre_allocate_past_size,
         )

-        # Initialize past_key_values in prefill
-        if batch.past_key_values is None:
-            # Initialize past padding tensor
-            if self.past_pad is None:
-                self.past_pad = present.new_zeros(
-                    present.shape[0], 1, *present.shape[2:]
-                )
-            # Set in batch in case it needs to be used later in concatenate()
-            batch.past_pad = self.past_pad
-            if len(batch) == 1:
-                # present is already pre-padded
-                batch.past_key_values = present
-            else:
-                # Add padding after each sequence
-                # This will have the correct shape after the final past_key_values concatenation before the model
-                # forward
-                batch.past_key_values = [None, self.past_pad] * len(batch)
-
         # Cumulative length
         cumulative_length = 0
@@ -508,6 +486,9 @@ class FlashCausalLM(Model):
             batch.all_input_ids_tensor,
         )

+        next_input_ids = input_ids.new_empty(len(batch.requests))
+        past_indices = []
+
         # For each member of the batch
         for i, (
             request,
@@ -538,14 +519,56 @@ class FlashCausalLM(Model):
                 all_input_ids_tensor[None, :input_length], logits
             )
             next_token_id_squeezed = next_token_id.squeeze()
-            next_token_id_item = next_token_id_squeezed.item()
+            all_input_ids_tensor[input_length] = next_token_id_squeezed
+            next_input_ids[i] = next_token_id_squeezed
+            past_indices.extend([j for j in range(start_index + i, end_index + i)])
+
+        # Initialize past_key_values in prefill
+        if batch.past_key_values is None and len(batch) == 1:
+            # present is already pre-padded
+            batch.past_key_values = present
+
+        if len(batch) > 1:
+            batch.past_key_values = present.new_empty((present.shape[0], present.shape[1] + len(batch.requests), *present.shape[2:]))
+            batch.past_key_values[:, past_indices] = present
+
+        if prefill:
+            batch.position_ids = torch.tensor(batch.input_lengths, device=self.device)
+        else:
+            batch.position_ids = batch.position_ids + 1
+
+        next_token_ids = next_input_ids.tolist()
+
+        # Zipped iterator
+        iterator = zip(
+            batch.requests,
+            batch.input_lengths,
+            batch.offsets,
+            batch.token_offsets,
+            batch.next_token_choosers,
+            batch.stopping_criterias,
+            batch.all_input_ids,
+            batch.all_input_ids_tensor,
+        )
+
+        # For each member of the batch
+        for i, (
+            request,
+            input_length,
+            offset,
+            token_offset,
+            next_token_chooser,
+            stopping_criteria,
+            all_input_ids,
+            all_input_ids_tensor,
+        ) in enumerate(iterator):
+            next_token_id_item = next_token_ids[i]

             # Append next token to all tokens
             all_input_ids.append(next_token_id_item)
-            all_input_ids_tensor[input_length] = next_token_id_item

             # Generated token
-            next_token_logprob = logprobs[-1, next_token_id_item]
+            next_token_logprob = 0.0
             next_token_text, offset, token_offset = self.decode_token(
                 all_input_ids,
                 offset,
@@ -576,23 +599,24 @@ class FlashCausalLM(Model):
                 stopped = False
                 generated_text = None

-            # Prefill
-            if prefill:
-                # Remove generated token to only have prefill and add nan for first prompt token
-                prefill_logprobs = [float("nan")] + logprobs.gather(
-                    1, all_input_ids_tensor[1:input_length].unsqueeze(1)
-                ).squeeze(1)[:-1].tolist()
-                prefill_token_ids = all_input_ids[:-1]
-                prefill_texts = self.tokenizer.batch_decode(
-                    prefill_token_ids,
-                    clean_up_tokenization_spaces=False,
-                    skip_special_tokens=False,
-                )
-                prefill_tokens = PrefillTokens(
-                    prefill_token_ids, prefill_logprobs, prefill_texts
-                )
-            else:
-                prefill_tokens = None
+            prefill = stopping_criteria.current_tokens == 0
+            # # Prefill
+            # if prefill:
+            #     # Remove generated token to only have prefill and add nan for first prompt token
+            #     prefill_logprobs = [float("nan")] + logprobs.gather(
+            #         1, all_input_ids_tensor[1:input_length].unsqueeze(1)
+            #     ).squeeze(1)[:-1].tolist()
+            #     prefill_token_ids = all_input_ids[:-1]
+            #     prefill_texts = self.tokenizer.batch_decode(
+            #         prefill_token_ids,
+            #         clean_up_tokenization_spaces=False,
+            #         skip_special_tokens=False,
+            #     )
+            #     prefill_tokens = PrefillTokens(
+            #         prefill_token_ids, prefill_logprobs, prefill_texts
+            #     )
+            # else:
+            prefill_tokens = None

             generation = Generation(
                 request.id,
@@ -609,19 +633,16 @@ class FlashCausalLM(Model):
             new_input_length = input_length + 1

             # Update values
-            batch.input_ids[i] = next_token_id
-            batch.position_ids[i] = input_length
             batch.input_lengths[i] = new_input_length
             batch.offsets[i] = offset
             batch.token_offsets[i] = token_offset
             batch.all_input_ids[i] = all_input_ids
-            batch.all_input_ids_tensor[i] = all_input_ids_tensor
             batch.max_seqlen = max(batch.max_seqlen, new_input_length)
-            if len(batch) != 1:
-                # Add each sequence before its padding
-                batch.past_key_values[i * 2] = present[:, start_index:end_index]
             # Cumulative sum
             batch.cu_seqlens[(i + 1)] = batch.cu_seqlens[i] + new_input_length
+
+        batch.input_ids = next_input_ids
+
         # No need to return a batch if we know that all requests stopped
         return generations, batch if not stopped else None
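
The central change in the hunks above is that past_key_values is no longer a Python list interleaved with per-sequence padding tensors; instead a single tensor is pre-allocated with one spare slot per sequence and the freshly computed present is scattered into it through past_indices. The snippet below is a minimal, standalone sketch of that scatter pattern only, not part of the patch itself; the toy input_lengths and tensor shapes are invented for illustration.

# Illustrative sketch (not part of the patch): pre-allocate the KV cache with
# one spare slot per sequence and scatter the prefill "present" into it.
import torch

input_lengths = [3, 5]                         # two sequences in the batch (toy values)
total_tokens = sum(input_lengths)              # rows produced during prefill
present = torch.randn(4, total_tokens, 2, 8)   # (kv stack, tokens, heads, head_dim) - invented shape

# Build past_indices the same way the diff does: shift sequence i right by i,
# leaving one empty slot after each sequence for its next generated token.
past_indices = []
start_index = 0
for i, length in enumerate(input_lengths):
    end_index = start_index + length
    past_indices.extend(range(start_index + i, end_index + i))
    start_index = end_index

# Pre-allocate one extra slot per sequence, then scatter present into place.
past_key_values = present.new_empty(
    (present.shape[0], present.shape[1] + len(input_lengths), *present.shape[2:])
)
past_key_values[:, past_indices] = present

print(past_key_values.shape)  # torch.Size([4, 10, 2, 8]); slots 3 and 9 stay free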