From f6df8db68049b203129e65423a872c4cfc1c3525 Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Thu, 4 May 2023 19:37:12 +0200
Subject: [PATCH] wip

---
 .../flash_santacoder_modeling.py              |   7 +-
 .../models/flash_causal_lm.py                 | 212 +++++++++++-------
 2 files changed, 133 insertions(+), 86 deletions(-)

diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
index 71182f8d..20ad8385 100644
--- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
@@ -484,6 +484,7 @@ class FlashSantacoderModel(nn.Module):
         input_ids,
         position_ids,
         cu_seqlens,
+        cu_seqlens_q,
         max_s,
         past_key_values: Optional[torch.Tensor] = None,
         pre_allocate_past_size: Optional[int] = None,
@@ -507,15 +508,11 @@ class FlashSantacoderModel(nn.Module):
                     )
                 )
             layer_past_present_indices = None
-            cu_seqlens_q = None
             slice_past_index = len(hidden_states)
         # Decode
         else:
             # Create indices from cumulative sequence lengths
             layer_past_present_indices = cu_seqlens[1:] - 1
-            cu_seqlens_q = torch.arange(
-                cu_seqlens.shape[0], dtype=torch.int32, device=hidden_states.device
-            )
             slice_past_index = None

         residual = None
@@ -566,6 +563,7 @@ class FlashSantacoderForCausalLM(nn.Module):
         input_ids,
         position_ids,
         cu_seqlens,
+        cu_seqlens_q,
         max_s,
         past_key_values: Optional[torch.Tensor] = None,
         pre_allocate_past_size: Optional[int] = None,
@@ -574,6 +572,7 @@ class FlashSantacoderForCausalLM(nn.Module):
             input_ids,
             position_ids,
             cu_seqlens,
+            cu_seqlens_q,
             max_s,
             past_key_values,
             pre_allocate_past_size,
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 29cc9848..d36acb84 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -39,6 +39,7 @@ class FlashCausalLMBatch(Batch):
     position_ids: torch.Tensor
     # cumulative sequence lengths
     cu_seqlens: torch.Tensor
+    cu_seqlens_q: Optional[torch.Tensor]
     max_seqlen: int
     past_key_values: Optional[torch.Tensor]

@@ -68,10 +69,10 @@ class FlashCausalLMBatch(Batch):

     @classmethod
     def from_pb(
-            cls,
-            pb: generate_pb2.Batch,
-            tokenizer: PreTrainedTokenizerBase,
-            device: torch.device,
+        cls,
+        pb: generate_pb2.Batch,
+        tokenizer: PreTrainedTokenizerBase,
+        device: torch.device,
     ) -> "FlashCausalLMBatch":
         position_ids = []
         cu_seqlens = [0]
@@ -127,11 +128,13 @@ class FlashCausalLMBatch(Batch):
             cumulative_length += input_length
             max_tokens += input_length + max_new_tokens

-        input_ids = torch.tensor(np.concatenate(all_input_ids), dtype=torch.int32, device=device)
-        position_ids = torch.tensor(np.concatenate(position_ids), dtype=torch.int32, device=device)
-        cu_seqlens = torch.tensor(
-            cu_seqlens, device=device, dtype=torch.int32
-        )
+        input_ids = torch.tensor(
+            np.concatenate(all_input_ids), dtype=torch.int32, device=device
+        )
+        position_ids = torch.tensor(
+            np.concatenate(position_ids), dtype=torch.int32, device=device
+        )
+        cu_seqlens = torch.tensor(cu_seqlens, device=device, dtype=torch.int32)

         return cls(
             batch_id=pb.id,
@@ -140,6 +143,7 @@ class FlashCausalLMBatch(Batch):
             input_ids=input_ids,
             position_ids=position_ids,
             cu_seqlens=cu_seqlens,
+            cu_seqlens_q=None,
             max_seqlen=max_seqlen,
             past_key_values=None,
             input_lengths=input_lengths,
@@ -218,7 +222,7 @@ class FlashCausalLMBatch(Batch):

             cumulative_length += request_input_length
             max_tokens += request_input_length + (
-                    stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
+                stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
             )

         if single_request:
@@ -354,12 +358,12 @@ class FlashCausalLMBatch(Batch):

 class FlashCausalLM(Model):
     def __init__(
-            self,
-            model_cls: Type[PreTrainedModel],
-            model_id: str,
-            revision: Optional[str] = None,
-            quantize: bool = False,
-            decode_buffer: int = 3,
+        self,
+        model_cls: Type[PreTrainedModel],
+        model_id: str,
+        revision: Optional[str] = None,
+        quantize: bool = False,
+        decode_buffer: int = 3,
     ):
         if torch.cuda.is_available():
             device = torch.device("cuda")
@@ -399,19 +403,21 @@ class FlashCausalLM(Model):
         )

     def forward(
-            self,
-            input_ids: torch.Tensor,
-            position_ids: torch.Tensor,
-            cu_seqlens: torch.Tensor,
-            max_s: int,
-            past_key_values: Optional = None,
-            pre_allocate_past_size: Optional[int] = None,
+        self,
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+        cu_seqlens: torch.Tensor,
+        cu_seqlens_q: Optional[torch.Tensor],
+        max_s: int,
+        past_key_values: Optional = None,
+        pre_allocate_past_size: Optional[int] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Model Forward
         return self.model.forward(
             input_ids=input_ids,
             position_ids=position_ids,
             cu_seqlens=cu_seqlens,
+            cu_seqlens_q=cu_seqlens_q,
             max_s=max_s,
             past_key_values=past_key_values,
             pre_allocate_past_size=pre_allocate_past_size,
@@ -419,16 +425,16 @@ class FlashCausalLM(Model):

     @tracer.start_as_current_span("generate_token")
     def generate_token(
-            self, batch: FlashCausalLMBatch
+        self, batch: FlashCausalLMBatch
     ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
-        # Shortcut when batch_size == 1
+        prefill = batch.past_key_values is None

-        # if prefill and bs == 1
-        if batch.past_key_values is None and len(batch) == 1:
+        # Shortcut when batch_size == 1
+        if prefill and len(batch) == 1:
             # Ask to pre-allocate kv to its max size
             # == number of tokens + max_new_tokens
             pre_allocate_past_size = (
-                    batch.input_lengths[0] + batch.stopping_criterias[0].max_new_tokens
+                batch.input_lengths[0] + batch.stopping_criterias[0].max_new_tokens
             )
         else:
             pre_allocate_past_size = None
@@ -437,11 +443,23 @@ class FlashCausalLM(Model):
             batch.input_ids,
             batch.position_ids,
             batch.cu_seqlens,
+            batch.cu_seqlens_q,
             batch.max_seqlen,
             batch.past_key_values,
             pre_allocate_past_size,
         )

+        if prefill:
+            # Compute logprobs for the whole batch
+            prefill_logprobs_tensor = torch.log_softmax(out, -1)
+        else:
+            prefill_logprobs_tensor = None
+
+        # Used to slice next batch past
+        past_indices = []
+        prefill_logprobs = []
+        next_token_logprobs = []
+
         # Cumulative length
         cumulative_length = 0

@@ -451,28 +469,18 @@ class FlashCausalLM(Model):

         # Zipped iterator
         iterator = zip(
-            batch.requests,
             batch.input_lengths,
-            batch.offsets,
-            batch.token_offsets,
             batch.next_token_choosers,
             batch.stopping_criterias,
             batch.all_input_ids,
         )

-        past_indices = []
-
-        prefill = batch.past_key_values is None
-
         # For each member of the batch
         for i, (
-                request,
-                input_length,
-                offset,
-                token_offset,
-                next_token_chooser,
-                stopping_criteria,
-                all_input_ids,
+            input_length,
+            next_token_chooser,
+            stopping_criteria,
+            all_input_ids,
         ) in enumerate(iterator):
             # Indexing metadata
             start_index = cumulative_length
@@ -481,21 +489,34 @@ class FlashCausalLM(Model):
             if prefill:
                 # Prefill mode
                 # out is of shape [cumulative_sequence_lengths, vocab_size]

-                logits = out[start_index:end_index]
-
-                batch.all_input_ids_tensor.append(
-                    F.pad(batch.input_ids[start_index:end_index], (0, stopping_criteria.max_new_tokens))
-                )
+                # only take last token logit
+                logits = out[end_index - 1 : end_index]
+                all_input_ids_tensor = F.pad(
+                    batch.input_ids[start_index:end_index],
+                    (0, stopping_criteria.max_new_tokens),
+                )
+                batch.all_input_ids_tensor.append(all_input_ids_tensor)
+                batch.position_ids[i] = input_length
+
+                prefill_logprobs.append(
+                    prefill_logprobs_tensor[start_index:end_index]
+                    .gather(
+                        1,
+                        all_input_ids_tensor[1:input_length]
+                        .unsqueeze(1)
+                        .to(torch.int64),
+                    )
+                    .squeeze(1)[:-1]
+                )
             else:
                 # Decode mode
                 # out is of shape [batch_size, vocab_size]
                 logits = out[i].unsqueeze(0)
-
             all_input_ids_tensor = batch.all_input_ids_tensor[i]

             # Select next token
-            next_token_id, logprobs = next_token_chooser(
+            next_token_id, logprob = next_token_chooser(
                 all_input_ids_tensor[None, :input_length], logits
             )
             next_token_id_squeezed = next_token_id.squeeze()
@@ -503,27 +524,49 @@ class FlashCausalLM(Model):

             past_indices.extend([j for j in range(start_index + i, end_index + i)])
             batch.input_ids[i] = next_token_id_squeezed
+            next_token_logprobs.append(logprob[-1, next_token_id])
+
             cumulative_length += input_length

         if prefill:
-            batch.input_ids = batch.input_ids[:len(batch)]
-            batch.position_ids = batch.position_ids[:len(batch)]
+            batch.input_ids = batch.input_ids[: len(batch)]
+            batch.position_ids = batch.position_ids[: len(batch)]
+            batch.cu_seqlens_q = torch.arange(
+                0, len(batch) + 1, device=self.device, dtype=torch.int32
+            )
         else:
             batch.position_ids += 1

         # Initialize past_key_values in prefill
-        if batch.past_key_values is None and len(batch) == 1:
+        if prefill and len(batch) == 1:
             # present is already pre-padded
             batch.past_key_values = present
+
+        batch.cu_seqlens = batch.cu_seqlens + batch.cu_seqlens_q
+
         if len(batch) > 1:
+            prefill_logprobs = torch.cat(prefill_logprobs) if prefill else None
+            next_token_logprobs = torch.cat(next_token_logprobs)
+
             batch.past_key_values = present.new_empty(
-                (present.shape[0], present.shape[1] + len(batch.requests), *present.shape[2:]))
+                (
+                    present.shape[0],
+                    present.shape[1] + len(batch.requests),
+                    *present.shape[2:],
+                )
+            )
             batch.past_key_values[:, past_indices] = present

-        batch.cu_seqlens = batch.cu_seqlens + torch.arange(0, len(batch) + 1, device=self.device, dtype=torch.int32)
+            prefill_logprobs = prefill_logprobs.to("cpu") if prefill else None
+            next_token_logprobs = next_token_logprobs.to("cpu")
+        else:
+            prefill_logprobs = prefill_logprobs[0] if prefill else None
+            next_token_logprobs = next_token_logprobs[0]

-        next_token_ids = batch.input_ids.to("cpu").detach()
+        next_token_ids = batch.input_ids.to("cpu")
+
+        prefill_logprobs_cumulative_length = 0

         # Zipped iterator
         iterator = zip(
@@ -535,26 +578,29 @@ class FlashCausalLM(Model):
             batch.stopping_criterias,
             batch.all_input_ids,
             batch.all_input_ids_tensor,
+            next_token_ids,
+            next_token_logprobs,
         )

         # For each member of the batch
         for i, (
-                request,
-                input_length,
-                offset,
-                token_offset,
-                next_token_chooser,
-                stopping_criteria,
-                all_input_ids,
-                all_input_ids_tensor,
+            request,
+            input_length,
+            offset,
+            token_offset,
+            next_token_chooser,
+            stopping_criteria,
+            all_input_ids,
+            all_input_ids_tensor,
+            next_token_id,
+            next_token_logprob,
         ) in enumerate(iterator):
-            next_token_id_item = next_token_ids[i]
+            next_token_id_item = next_token_id.item()

             # Append next token to all tokens
             all_input_ids.append(next_token_id_item)

             # Generated token
-            next_token_logprob = 0.0
             next_token_text, offset, token_offset = self.decode_token(
                 all_input_ids,
                 offset,
@@ -570,7 +616,7 @@ class FlashCausalLM(Model):
             if stop:
                 # Decode generated tokens
                 output_text = self.decode(
-                    all_input_ids[-stopping_criteria.current_tokens:]
+                    all_input_ids[-stopping_criteria.current_tokens :]
                 )
                 # Get seed
                 if isinstance(next_token_chooser.choice, Sampling):
@@ -585,36 +631,38 @@ class FlashCausalLM(Model):
                 stopped = False
                 generated_text = None

-            # # Prefill
-            # if prefill:
-            #     # Remove generated token to only have prefill and add nan for first prompt token
-            #     prefill_logprobs = [float("nan")] + logprobs.gather(
-            #         1, all_input_ids_tensor[1:input_length].unsqueeze(1)
-            #     ).squeeze(1)[:-1].tolist()
-            #     prefill_token_ids = all_input_ids[:-1]
-            #     prefill_texts = self.tokenizer.batch_decode(
-            #         prefill_token_ids,
-            #         clean_up_tokenization_spaces=False,
-            #         skip_special_tokens=False,
-            #     )
-            #     prefill_tokens = PrefillTokens(
-            #         prefill_token_ids, prefill_logprobs, prefill_texts
-            #     )
-            # else:
-            prefill_tokens = None
+            # Prefill
+            if prefill:
+                start_index = prefill_logprobs_cumulative_length
+                end_index = prefill_logprobs_cumulative_length + input_length - 1
+
+                # Remove generated token to only have prefill and add nan for first prompt token
+                request_prefill_logprobs = [float("nan")] + prefill_logprobs[start_index:end_index].tolist()
+                prefill_token_ids = all_input_ids[:-1]
+                prefill_texts = self.tokenizer.batch_decode(
+                    prefill_token_ids,
+                    clean_up_tokenization_spaces=False,
+                    skip_special_tokens=False,
+                )
+                prefill_tokens = PrefillTokens(
+                    prefill_token_ids, request_prefill_logprobs, prefill_texts
+                )
+
+                prefill_logprobs_cumulative_length += input_length - 1
+            else:
+                prefill_tokens = None

             generation = Generation(
                 request.id,
                 prefill_tokens,
                 next_token_id_item,
-                next_token_logprob,
+                next_token_logprob.item(),
                 next_token_text,
                 next_token_id_item in self.all_special_ids,
                 generated_text,
             )

             generations.append(generation)
-            cumulative_length += input_length
             new_input_length = input_length + 1

             # Update values
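
Note (not part of the patch): a minimal sketch of the cu_seqlens / cu_seqlens_q bookkeeping this change moves out of the model and into FlashCausalLMBatch / generate_token. The batch size and prompt lengths below are made up for illustration; only torch is required.

import torch

# Hypothetical batch of 3 requests with prompt lengths 2, 5 and 3.
input_lengths = [2, 5, 3]

# Prefill: cumulative lengths over the packed prompt tokens, built the same
# way FlashCausalLMBatch.from_pb builds them. cu_seqlens_q stays None here.
cu = [0]
for length in input_lengths:
    cu.append(cu[-1] + length)
cu_seqlens = torch.tensor(cu, dtype=torch.int32)  # -> [0, 2, 7, 10]

# After prefill, every request contributes exactly one query token per step,
# so the query cumulative lengths are simply 0..batch_size.
cu_seqlens_q = torch.arange(0, len(input_lengths) + 1, dtype=torch.int32)  # -> [0, 1, 2, 3]

# Each decode step grows every sequence by one token, which is what
# `batch.cu_seqlens = batch.cu_seqlens + batch.cu_seqlens_q` does in
# generate_token().
cu_seqlens = cu_seqlens + cu_seqlens_q
print(cu_seqlens)  # -> [0, 3, 9, 13]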