diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index de9b22da..1293124a 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -554,6 +554,7 @@ class FlashLlamaModel(torch.nn.Module):
         input_ids,
         position_ids,
         cu_seqlens,
+        cu_seqlens_q,
         max_s,
         past_key_values: Optional[torch.Tensor] = None,
         pre_allocate_past_size: Optional[int] = None,
@@ -575,15 +576,11 @@ class FlashLlamaModel(torch.nn.Module):
                 )
             )
             layer_past_present_indices = None
-            cu_seqlens_q = None
             slice_past_index = len(hidden_states)
         # Decode
         else:
             # Create indices from cumulative sequence lengths
             layer_past_present_indices = cu_seqlens[1:] - 1
-            cu_seqlens_q = torch.arange(
-                cu_seqlens.shape[0], dtype=torch.int32, device=hidden_states.device
-            )
             slice_past_index = None
 
         # Get rotary cos and sin for this forward
@@ -650,6 +647,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         input_ids,
         position_ids,
         cu_seqlens,
+        cu_seqlens_q,
         max_s,
         past_key_values: Optional[torch.Tensor] = None,
         pre_allocate_past_size: Optional[int] = None,
@@ -658,6 +656,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
             input_ids,
             position_ids,
             cu_seqlens,
+            cu_seqlens_q,
             max_s,
             past_key_values,
             pre_allocate_past_size,
diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
index cc9b292f..ae1465ab 100644
--- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
@@ -617,6 +617,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
         input_ids,
         position_ids,
         cu_seqlens,
+        cu_seqlens_q,
         max_s,
         past_key_values=None,
         pre_allocate_past_size: Optional[int] = None,
@@ -638,15 +639,11 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
                 )
             )
             layer_past_present_indices = None
-            cu_seqlens_q = None
             slice_past_index = len(hidden_states)
         # Decode
         else:
             # Create indices from cumulative sequence lengths
             layer_past_present_indices = cu_seqlens[1:] - 1
-            cu_seqlens_q = torch.arange(
-                cu_seqlens.shape[0], dtype=torch.int32, device=hidden_states.device
-            )
             slice_past_index = None
 
         # Get rotary cos and sin for this forward
@@ -726,6 +723,7 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
         input_ids,
         position_ids,
         cu_seqlens,
+        cu_seqlens_q,
         max_s,
         past_key_values: Optional[torch.Tensor] = None,
         pre_allocate_past_size: Optional[int] = None,
@@ -734,6 +732,7 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
            input_ids,
            position_ids,
            cu_seqlens,
+            cu_seqlens_q,
            max_s,
            past_key_values,
            pre_allocate_past_size,
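Note on the signature change above: both models now receive `cu_seqlens_q` from the batch instead of rebuilding it with `torch.arange` on every decode forward. As a rough standalone illustration (not code from this repository; the helper name and shapes are made up), `cu_seqlens` holds the cumulative key/value lengths of the packed batch, while in decode every sequence contributes exactly one query token, so `cu_seqlens_q` is simply `0..batch_size` and can be computed once and reused:

```python
# Illustrative sketch only: the metadata shape expected by varlen flash-attention
# kernels. `build_seqlen_metadata` is a hypothetical helper, not part of TGI.
import torch


def build_seqlen_metadata(input_lengths, device="cpu"):
    # cu_seqlens: cumulative key/value lengths of the packed batch,
    # e.g. lengths [3, 5, 2] -> [0, 3, 8, 10]
    lengths = torch.tensor(input_lengths, dtype=torch.int32, device=device)
    cu_seqlens = torch.zeros(len(input_lengths) + 1, dtype=torch.int32, device=device)
    cu_seqlens[1:] = torch.cumsum(lengths, dim=0)

    # cu_seqlens_q: cumulative query lengths. In decode every sequence feeds
    # exactly one new query token, so this is simply 0..batch_size.
    cu_seqlens_q = torch.arange(
        0, len(input_lengths) + 1, dtype=torch.int32, device=device
    )
    return cu_seqlens, cu_seqlens_q


if __name__ == "__main__":
    cu_seqlens, cu_seqlens_q = build_seqlen_metadata([3, 5, 2])
    print(cu_seqlens.tolist())    # [0, 3, 8, 10]
    print(cu_seqlens_q.tolist())  # [0, 1, 2, 3]
```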
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index d36acb84..ba318d14 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -37,11 +37,14 @@ class FlashCausalLMBatch(Batch):
 
     # Decoder values
     input_ids: torch.Tensor
     position_ids: torch.Tensor
+    # cumulative sequence lengths
     cu_seqlens: torch.Tensor
+    # cumulative query sequence lengths, only used in decode
     cu_seqlens_q: Optional[torch.Tensor]
-    max_seqlen: int
+    # past key values, only used in decode
     past_key_values: Optional[torch.Tensor]
+    max_seqlen: int
 
     # All tokens
     all_input_ids: List[List[int]]
@@ -128,8 +131,9 @@ class FlashCausalLMBatch(Batch):
             cumulative_length += input_length
             max_tokens += input_length + max_new_tokens
 
+        # Create tensors on device
         input_ids = torch.tensor(
-            np.concatenate(all_input_ids), dtype=torch.int32, device=device
+            np.concatenate(all_input_ids), dtype=torch.int64, device=device
         )
         position_ids = torch.tensor(
             np.concatenate(position_ids), dtype=torch.int32, device=device
@@ -172,9 +176,13 @@ class FlashCausalLMBatch(Batch):
         # New values after filtering
         requests_idx_mapping = {}
 
-        input_ids = []
-        position_ids = []
-        cu_seqlens = [0]
+        input_ids = self.input_ids.new_empty(len(requests))
+        position_ids = self.position_ids.new_empty(len(requests))
+        # Create on CPU to only move to GPU once instead of at every copy
+        cu_seqlens = torch.zeros(len(requests) + 1, dtype=torch.int32)
+        cu_seqlens_q = torch.arange(
+            0, len(requests) + 1, device=self.cu_seqlens_q.device, dtype=torch.int32
+        )
         max_seqlen = 0
         past_key_values = []
 
@@ -197,16 +205,18 @@ class FlashCausalLMBatch(Batch):
             # Get length
             request_input_length = self.input_lengths[idx]
 
-            input_ids.append(self.input_ids[idx])
-            position_ids.append(self.position_ids[idx])
-            cu_seqlens.append(cumulative_length + request_input_length)
-            max_seqlen = max(max_seqlen, request_input_length)
-            # True index for past
-            past_key_values.append(self.past_key_values[2 * idx])
+            # Copy tensors (GPU)
+            input_ids[i] = self.input_ids[idx]
+            position_ids[i] = self.position_ids[idx]
 
-            if not single_request:
-                # Add one padding
-                past_key_values.append(self.past_pad)
+            # Copy to tensor (CPU)
+            cu_seqlens[i + 1] = cumulative_length + request_input_length
+            max_seqlen = max(max_seqlen, request_input_length)
+
+            # Slice from past
+            past_key_values.append(
+                self.past_key_values[:, self.cu_seqlens[idx] : self.cu_seqlens[idx + 1]]
+            )
 
             all_input_ids.append(self.all_input_ids[idx])
             all_input_ids_tensor.append(self.all_input_ids_tensor[idx])
@@ -227,7 +237,7 @@ class FlashCausalLMBatch(Batch):
 
         if single_request:
             # Preallocate tensor for bs = 1 case
-            past_key_values = torch.nn.functional.pad(
+            past_key_values = F.pad(
                 past_key_values[0],
                 (
                     0,
@@ -241,15 +251,21 @@ class FlashCausalLMBatch(Batch):
                     - stopping_criterias[0].current_tokens,
                 ),
             )
+        else:
+            # Cat all past
+            past_key_values = torch.cat(past_key_values, dim=1)
+
+        # Move to GPU now that we have the whole tensor
+        cu_seqlens = cu_seqlens.to(self.cu_seqlens.device)
 
         return FlashCausalLMBatch(
             batch_id=self.batch_id,
-            past_pad=self.past_pad,
             requests=requests,
             requests_idx_mapping=requests_idx_mapping,
             input_ids=input_ids,
             position_ids=position_ids,
             cu_seqlens=cu_seqlens,
+            cu_seqlens_q=cu_seqlens_q,
             max_seqlen=max_seqlen,
             past_key_values=past_key_values,
             input_lengths=input_lengths,
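Review note: `filter` above stops storing one padded past tensor per request (plus the `past_pad` entries) and instead slices each surviving request's region out of the token-packed `past_key_values` using the old `cu_seqlens` offsets, writing the small per-request values into preallocated tensors. A minimal sketch of that slicing under the `[num_layers, total_tokens, ...]` packing assumed by this batch; the toy shapes and variable names below are illustrative only:

```python
# Toy illustration of keeping a subset of requests from a token-packed KV cache.
import torch

num_layers, head_dim = 2, 4
cu_seqlens = torch.tensor([0, 3, 8, 10], dtype=torch.int32)  # lengths [3, 5, 2]
past = torch.randn(num_layers, int(cu_seqlens[-1]), head_dim)

keep = [0, 2]  # indices of the requests that are still generating
kept_slices = [past[:, cu_seqlens[i] : cu_seqlens[i + 1]] for i in keep]
filtered_past = torch.cat(kept_slices, dim=1)  # re-packed cache for the new batch

print(filtered_past.shape)  # torch.Size([2, 5, 4]): 3 + 2 remaining tokens
```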
@@ -269,9 +285,16 @@ class FlashCausalLMBatch(Batch):
         requests = []
         requests_idx_mapping = {}
 
-        input_ids = []
-        position_ids = []
+        total_batch_size = sum([len(b) for b in batches])
+
+        device = batches[0].input_ids.device
+
+        input_ids = batches[0].input_ids.new_empty(total_batch_size)
+        position_ids = batches[0].position_ids.new_empty(total_batch_size)
         cu_seqlens = [0]
+        cu_seqlens_q = torch.arange(
+            0, total_batch_size + 1, device=device, dtype=torch.int32
+        )
         max_seqlen = 0
         past_key_values = []
 
@@ -300,22 +323,25 @@ class FlashCausalLMBatch(Batch):
             for k, v in batch.requests_idx_mapping.items():
                 requests_idx_mapping[k] = v + cumulative_batch_size
 
-            input_ids.extend(batch.input_ids)
-            position_ids.extend(batch.position_ids)
+            start_index = cumulative_batch_size
+            end_index = cumulative_batch_size + len(batch)
+
+            # Copy tensors (GPU)
+            input_ids[start_index:end_index] = batch.input_ids
+            position_ids[start_index:end_index] = batch.position_ids
+
             # Add cumulative lengths of all previous inputs
             cu_seqlens.extend([l + cumulative_length for l in batch.cu_seqlens[1:]])
             max_seqlen = max(max_seqlen, batch.max_seqlen)
 
             if len(batch) != 1:
-                past_key_values.extend(batch.past_key_values)
+                past_key_values.append(batch.past_key_values)
             else:
                 # past was pre-allocated for this batch
                 # We need to slice to remove the padding
                 past_key_values.append(
                     batch.past_key_values[:, : batch.input_lengths[0]]
                 )
-                # Add one padding
-                past_key_values.append(batch.past_pad)
 
             all_input_ids.extend(batch.all_input_ids)
             all_input_ids_tensor.extend(batch.all_input_ids_tensor)
@@ -332,14 +358,19 @@ class FlashCausalLMBatch(Batch):
             cumulative_batch_size += len(batch)
             max_tokens += batch.max_tokens
 
+        # Cat past
+        past_key_values = torch.cat(past_key_values, dim=1)
+        # Create final tensor on GPU
+        cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32, device=device)
+
         return FlashCausalLMBatch(
             batch_id=batches[0].batch_id,
-            past_pad=batches[0].past_pad,
             requests=requests,
             requests_idx_mapping=requests_idx_mapping,
             input_ids=input_ids,
             position_ids=position_ids,
             cu_seqlens=cu_seqlens,
+            cu_seqlens_q=cu_seqlens_q,
             max_seqlen=max_seqlen,
             past_key_values=past_key_values,
             input_lengths=input_lengths,
@@ -367,7 +398,7 @@ class FlashCausalLM(Model):
     ):
         if torch.cuda.is_available():
             device = torch.device("cuda")
-            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+            dtype = torch.float16
         else:
             raise NotImplementedError("FlashCausalLM is only available on GPU")
 
@@ -429,7 +460,6 @@ class FlashCausalLM(Model):
     ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
         prefill = batch.past_key_values is None
 
-        # Shortcut when batch_size == 1
         if prefill and len(batch) == 1:
             # Ask to pre-allocate kv to its max size
             # == number of tokens + max_new_tokens
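Similarly, `concatenate` now merges the packed caches with a single `torch.cat` along the token dimension and rebuilds `cu_seqlens` on CPU before one final transfer to the GPU. A hedged sketch of that bookkeeping (hypothetical helper, toy tensors, same packing assumption as above):

```python
# Toy illustration of merging token-packed caches and rebuilding cu_seqlens.
import torch


def concat_packed(pasts, cu_seqlens_list):
    # pasts: list of [num_layers, tokens_i, ...] caches
    # cu_seqlens_list: per-batch cumulative lengths, each starting at 0
    merged_past = torch.cat(pasts, dim=1)
    cu_seqlens, offset = [0], 0
    for cu in cu_seqlens_list:
        # shift each batch's offsets by the number of tokens already merged
        cu_seqlens.extend(int(l) + offset for l in cu[1:])
        offset += int(cu[-1])
    return merged_past, torch.tensor(cu_seqlens, dtype=torch.int32)


a = torch.randn(2, 5, 4)  # two sequences, lengths [3, 2]
b = torch.randn(2, 4, 4)  # one sequence, length [4]
merged, cu = concat_packed([a, b], [torch.tensor([0, 3, 5]), torch.tensor([0, 4])])
print(merged.shape, cu.tolist())  # torch.Size([2, 9, 4]) [0, 3, 5, 9]
```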
@@ -450,15 +480,62 @@ class FlashCausalLM(Model):
         )
 
         if prefill:
-            # Compute logprobs for the whole batch
-            prefill_logprobs_tensor = torch.log_softmax(out, -1)
-        else:
-            prefill_logprobs_tensor = None
+            if len(batch) > 1:
+                # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
+                # When batch == 1, we will just use the batch.input_ids values directly
+                prefill_tokens_indices = batch.input_ids.new_zeros(len(batch.input_ids))
 
-        # Used to slice next batch past
-        past_indices = []
-        prefill_logprobs = []
-        next_token_logprobs = []
+            # Create batch.cu_seqlens_q for decode
+            batch.cu_seqlens_q = torch.arange(
+                0, len(batch) + 1, device=self.device, dtype=torch.int32
+            )
+            next_input_ids = batch.input_ids.new_empty(len(batch))
+            next_position_ids = batch.position_ids.new_empty(len(batch))
+        else:
+            prefill_logprobs = None
+            next_input_ids = batch.input_ids
+            next_position_ids = batch.position_ids
+
+        next_token_logprobs = out.new_empty(len(batch))
+
+        # Prepare past for next decode
+        if len(batch) > 1:
+            # Used to slice next batch past
+            past_indices = torch.empty(
+                present.shape[1], dtype=torch.int64, device=self.device
+            )
+            batch.past_key_values = present.new_empty(
+                (
+                    present.shape[0],
+                    present.shape[1] + len(batch.requests),
+                    *present.shape[2:],
+                )
+            )
+
+            # It is actually faster to do a whole other for loop here as the copy from present to past is fairly slow
+            # and will run asynchronously while we do the next for loop
+            cumulative_length = 0
+            for i, input_length in enumerate(batch.input_lengths):
+                # Indexing metadata
+                start_index = cumulative_length
+                end_index = cumulative_length + input_length
+
+                # Indices to copy present at the correct place in past_key_values
+                past_indices[start_index:end_index] = torch.arange(
+                    start_index + i,
+                    end_index + i,
+                    dtype=torch.int64,
+                    device=self.device,
+                )
+                cumulative_length += input_length
+
+            # Copy from present to past_key_values
+            batch.past_key_values[:, past_indices] = present
+
+        # Initialize past_key_values in prefill for len(batch) == 1
+        elif prefill:
+            # present is already pre-padded
+            batch.past_key_values = present
 
         # Cumulative length
         cumulative_length = 0
@@ -475,6 +552,10 @@
             batch.all_input_ids,
         )
 
+        # We do two for loops as the first one can run completely asynchronously from the GPU while for the second
+        # one, we need to first do a GPU <-> CPU sync
+        # It is faster if we delay this sync for the maximum amount of time
+
         # For each member of the batch
         for i, (
             input_length,
@@ -491,23 +572,32 @@
                 # out is of shape [cumulative_sequence_lengths, vocab_size]
                 # only take last token logit
                 logits = out[end_index - 1 : end_index]
-                all_input_ids_tensor = F.pad(
-                    batch.input_ids[start_index:end_index],
-                    (0, stopping_criteria.max_new_tokens),
+
+                # Create all_input_ids_tensor that will be used by token warpers (for example, RepetitionPenalty)
+                all_input_ids_tensor = batch.input_ids.new_empty(
+                    input_length + stopping_criteria.max_new_tokens
                 )
+                # Copy from batch.input_ids to all_input_ids_tensor
+                all_input_ids_tensor[:input_length] = batch.input_ids[
+                    start_index:end_index
+                ]
                 batch.all_input_ids_tensor.append(all_input_ids_tensor)
 
-                batch.position_ids[i] = input_length
-                prefill_logprobs.append(
-                    prefill_logprobs_tensor[start_index:end_index]
-                    .gather(
-                        1,
-                        all_input_ids_tensor[1:input_length]
-                        .unsqueeze(1)
-                        .to(torch.int64),
-                    )
-                    .squeeze(1)[:-1]
-                )
+                # Initialize position_ids
+                # In decode, we do not need this as we can just increment position ids
+                next_position_ids[i] = batch.position_ids[end_index - 1]
+
+                # Used to gather prefill logprobs
+                # Copy batch.input_ids to prefill_token_indices
+                if len(batch) > 1:
+                    prefill_tokens_indices[
+                        start_index : end_index - 1
+                    ] = batch.input_ids[start_index + 1 : end_index]
+                else:
+                    # Set prefill_tokens_indices to the correct slice
+                    prefill_tokens_indices = batch.input_ids[
+                        start_index + 1 : end_index
+                    ]
             else:
                 # Decode mode
                 # out is of shape [batch_size, vocab_size]
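The `generate_token` changes above pre-allocate `batch.past_key_values` with one extra token slot per sequence and build `past_indices` so that sequence `i`'s prefill keys/values land at an offset shifted by `i`, leaving a free slot for the token it is about to generate. A standalone sketch of that index construction with a 1-D stand-in for the packed KV tensor (the real tensor also carries layer and head dimensions):

```python
# Sketch: scatter the prefill "present" KV into a past buffer that reserves one
# empty slot per sequence for the next decode step. Shapes are toy values.
import torch

input_lengths = [3, 2]                              # prefill tokens per sequence
total = sum(input_lengths)
present = torch.arange(total, dtype=torch.float32)  # stand-in for the packed KV

# One extra slot per sequence: [seq0 tokens, <free>, seq1 tokens, <free>, ...]
past = present.new_zeros(total + len(input_lengths))

past_indices = torch.empty(total, dtype=torch.int64)
cumulative_length = 0
for i, input_length in enumerate(input_lengths):
    start_index = cumulative_length
    end_index = cumulative_length + input_length
    # shift by i so each sequence keeps a free slot right after its tokens
    past_indices[start_index:end_index] = torch.arange(start_index + i, end_index + i)
    cumulative_length += input_length

past[past_indices] = present
print(past.tolist())  # [0.0, 1.0, 2.0, 0.0, 3.0, 4.0, 0.0]
```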
@@ -519,54 +609,36 @@
             next_token_id, logprob = next_token_chooser(
                 all_input_ids_tensor[None, :input_length], logits
             )
+
+            # Add to all_input_ids_tensor
             next_token_id_squeezed = next_token_id.squeeze()
             all_input_ids_tensor[input_length] = next_token_id_squeezed
 
-            past_indices.extend([j for j in range(start_index + i, end_index + i)])
-            batch.input_ids[i] = next_token_id_squeezed
-            next_token_logprobs.append(logprob[-1, next_token_id])
+            # Set values
+            next_input_ids[i] = next_token_id_squeezed
+            next_token_logprobs[i] = logprob[-1, next_token_id].view(1)
 
             cumulative_length += input_length
 
-        if prefill:
-            batch.input_ids = batch.input_ids[: len(batch)]
-            batch.position_ids = batch.position_ids[: len(batch)]
-            batch.cu_seqlens_q = torch.arange(
-                0, len(batch) + 1, device=self.device, dtype=torch.int32
-            )
-        else:
-            batch.position_ids += 1
-
-        # Initialize past_key_values in prefill
-        if prefill and len(batch) == 1:
-            # present is already pre-padded
-            batch.past_key_values = present
-
-
+        # Set values in batch
+        batch.input_ids = next_input_ids
+        batch.position_ids = next_position_ids + 1
         batch.cu_seqlens = batch.cu_seqlens + batch.cu_seqlens_q
 
-        if len(batch) > 1:
-            prefill_logprobs = torch.cat(prefill_logprobs) if prefill else None
-            next_token_logprobs = torch.cat(next_token_logprobs)
-
-            batch.past_key_values = present.new_empty(
-                (
-                    present.shape[0],
-                    present.shape[1] + len(batch.requests),
-                    *present.shape[2:],
-                )
+        if prefill:
+            # Get prefill logprobs
+            prefill_logprobs_tensor = torch.log_softmax(out, -1)
+            prefill_logprobs = torch.gather(
+                prefill_logprobs_tensor, 1, prefill_tokens_indices.unsqueeze(1)
             )
-            batch.past_key_values[:, past_indices] = present
+            # GPU <-> CPU sync
+            prefill_logprobs = prefill_logprobs.squeeze(1).to("cpu").numpy()
 
-            prefill_logprobs = prefill_logprobs.to("cpu") if prefill else None
-            next_token_logprobs = next_token_logprobs.to("cpu")
-        else:
-            prefill_logprobs = prefill_logprobs[0] if prefill else None
-            next_token_logprobs = next_token_logprobs[0]
+        # GPU <-> CPU sync
+        next_token_logprobs = next_token_logprobs.to("cpu").numpy()
+        next_token_ids = batch.input_ids.to("cpu").numpy()
 
-        next_token_ids = batch.input_ids.to("cpu")
-
-        prefill_logprobs_cumulative_length = 0
+        cumulative_length = 0
 
         # Zipped iterator
         iterator = zip(
@@ -595,10 +667,11 @@ class FlashCausalLM(Model):
             next_token_id,
             next_token_logprob,
         ) in enumerate(iterator):
-            next_token_id_item = next_token_id.item()
+            start_index = cumulative_length
+            end_index = cumulative_length + input_length
 
             # Append next token to all tokens
-            all_input_ids.append(next_token_id_item)
+            all_input_ids.append(next_token_id)
 
             # Generated token
             next_token_text, offset, token_offset = self.decode_token(
@@ -609,7 +682,7 @@ class FlashCausalLM(Model):
 
             # Evaluate stopping criteria
             stop, reason = stopping_criteria(
-                next_token_id_item,
+                next_token_id,
                 next_token_text,
             )
 
@@ -633,11 +706,10 @@ class FlashCausalLM(Model):
 
             # Prefill
             if prefill:
-                start_index = prefill_logprobs_cumulative_length
-                end_index = prefill_logprobs_cumulative_length + input_length - 1
-
                 # Remove generated token to only have prefill and add nan for first prompt token
-                request_prefill_logprobs = [float("nan")] + prefill_logprobs[start_index:end_index].tolist()
+                request_prefill_logprobs = [float("nan")] + prefill_logprobs[
+                    start_index : end_index - 1
+                ].tolist()
                 prefill_token_ids = all_input_ids[:-1]
                 prefill_texts = self.tokenizer.batch_decode(
                     prefill_token_ids,
@@ -647,18 +719,16 @@ class FlashCausalLM(Model):
                 prefill_tokens = PrefillTokens(
                     prefill_token_ids, request_prefill_logprobs, prefill_texts
                 )
-
-                prefill_logprobs_cumulative_length += input_length - 1
             else:
                 prefill_tokens = None
 
             generation = Generation(
                 request.id,
                 prefill_tokens,
-                next_token_id_item,
-                next_token_logprob.item(),
+                next_token_id,
+                next_token_logprob,
                 next_token_text,
-                next_token_id_item in self.all_special_ids,
+                next_token_id in self.all_special_ids,
                 generated_text,
             )
 
@@ -670,7 +740,8 @@ class FlashCausalLM(Model):
             batch.offsets[i] = offset
             batch.token_offsets[i] = token_offset
             batch.all_input_ids[i] = all_input_ids
-            batch.max_seqlen = max(batch.max_seqlen, new_input_length)
+            batch.max_seqlen = batch.max_seqlen + 1
+            cumulative_length += input_length
 
         # No need to return a batch if we know that all requests stopped
         return generations, batch if not stopped else None
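Prefill logprobs are now computed once for the whole batch instead of per request inside the Python loop: `prefill_tokens_indices` holds, for every prompt position, the id of the token that actually follows it, so a single `log_softmax` plus `gather` yields the prompt logprobs, and the first prompt token later gets a `NaN` placeholder. A compact sketch of that shift-and-gather with random logits and illustrative names (not the server's exact tensors):

```python
# Sketch of batched prompt-logprob extraction via a shifted gather.
import torch

vocab_size = 11
input_ids = torch.tensor([4, 7, 2, 9, 5])      # one packed prompt of 5 tokens
out = torch.randn(len(input_ids), vocab_size)  # logits for every prompt position

# Position p is scored against the token that follows it (p + 1); the last
# position predicts the first generated token, so it is dropped below.
prefill_tokens_indices = input_ids.new_zeros(len(input_ids))
prefill_tokens_indices[:-1] = input_ids[1:]

logprobs = torch.log_softmax(out, dim=-1)
prompt_logprobs = torch.gather(
    logprobs, 1, prefill_tokens_indices.unsqueeze(1)
).squeeze(1)

# The first prompt token has no preceding context: report NaN, as the server does.
request_prefill_logprobs = [float("nan")] + prompt_logprobs[:-1].tolist()
print(len(request_prefill_logprobs))  # 5
```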
diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py
index e640113b..9ef66f7f 100644
--- a/server/text_generation_server/models/flash_llama.py
+++ b/server/text_generation_server/models/flash_llama.py
@@ -32,7 +32,7 @@ class FlashLlama(FlashCausalLM):
         self.past_pad = None
         if torch.cuda.is_available():
             device = torch.device("cuda")
-            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+            dtype = torch.float16
         else:
             raise NotImplementedError("FlashLlama is only available on GPU")
 
@@ -152,7 +152,7 @@ class FlashLlamaSharded(FlashLlama):
         self.master = self.rank == 0
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{self.rank}")
-            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+            dtype = torch.float16
         else:
             raise NotImplementedError("FlashLlama is only available on GPU")
 
diff --git a/server/text_generation_server/models/flash_neox.py b/server/text_generation_server/models/flash_neox.py
index eae584ac..9759df94 100644
--- a/server/text_generation_server/models/flash_neox.py
+++ b/server/text_generation_server/models/flash_neox.py
@@ -38,7 +38,7 @@ class FlashNeoXSharded(FlashNeoX):
         self.master = self.rank == 0
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{self.rank}")
-            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+            dtype = torch.float16
         else:
             raise NotImplementedError("FlashNeoX is only available on GPU")
 
diff --git a/server/text_generation_server/models/flash_santacoder.py b/server/text_generation_server/models/flash_santacoder.py
index aa1bdfb5..5d3cce6a 100644
--- a/server/text_generation_server/models/flash_santacoder.py
+++ b/server/text_generation_server/models/flash_santacoder.py
@@ -31,7 +31,7 @@ class FlashSantacoder(FlashCausalLM):
         self.past_pad = None
         if torch.cuda.is_available():
             device = torch.device("cuda")
-            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+            dtype = torch.float16
         else:
             raise NotImplementedError("FlashSantacoder is only available on GPU")
 
@@ -178,7 +178,7 @@ class FlashSantacoderSharded(FlashSantacoder):
         self.master = self.rank == 0
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{self.rank}")
-            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+            dtype = torch.float16
         else:
             raise NotImplementedError("FlashSantacoderSharded is only available on GPU")