From ad66f6ef9ac3677e6259b85026f911a555970801 Mon Sep 17 00:00:00 2001
From: OlivierDehaene
Date: Tue, 9 May 2023 18:26:19 +0200
Subject: [PATCH] feat(server): optim flash causal lm decode_token (#285)

---
 .../custom_modeling/flash_llama_modeling.py   |   7 +-
 .../custom_modeling/flash_neox_modeling.py    |   7 +-
 .../flash_santacoder_modeling.py              |   7 +-
 .../models/flash_causal_lm.py                 | 377 ++++++++++++------
 .../models/flash_llama.py                     |   4 +-
 .../models/flash_neox.py                      |   2 +-
 .../models/flash_santacoder.py                |   4 +-
 7 files changed, 263 insertions(+), 145 deletions(-)

diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index de9b22da..1293124a 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -554,6 +554,7 @@ class FlashLlamaModel(torch.nn.Module):
         input_ids,
         position_ids,
         cu_seqlens,
+        cu_seqlens_q,
         max_s,
         past_key_values: Optional[torch.Tensor] = None,
         pre_allocate_past_size: Optional[int] = None,
@@ -575,15 +576,11 @@ class FlashLlamaModel(torch.nn.Module):
                 )
             )
             layer_past_present_indices = None
-            cu_seqlens_q = None
             slice_past_index = len(hidden_states)
         # Decode
         else:
             # Create indices from cumulative sequence lengths
             layer_past_present_indices = cu_seqlens[1:] - 1
-            cu_seqlens_q = torch.arange(
-                cu_seqlens.shape[0], dtype=torch.int32, device=hidden_states.device
-            )
             slice_past_index = None

         # Get rotary cos and sin for this forward
@@ -650,6 +647,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         input_ids,
         position_ids,
         cu_seqlens,
+        cu_seqlens_q,
         max_s,
         past_key_values: Optional[torch.Tensor] = None,
         pre_allocate_past_size: Optional[int] = None,
@@ -658,6 +656,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
             input_ids,
             position_ids,
             cu_seqlens,
+            cu_seqlens_q,
             max_s,
             past_key_values,
             pre_allocate_past_size,
diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
index cc9b292f..ae1465ab 100644
--- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
@@ -617,6 +617,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
         input_ids,
         position_ids,
         cu_seqlens,
+        cu_seqlens_q,
         max_s,
         past_key_values=None,
         pre_allocate_past_size: Optional[int] = None,
@@ -638,15 +639,11 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
                 )
             )
             layer_past_present_indices = None
-            cu_seqlens_q = None
             slice_past_index = len(hidden_states)
         # Decode
         else:
             # Create indices from cumulative sequence lengths
             layer_past_present_indices = cu_seqlens[1:] - 1
-            cu_seqlens_q = torch.arange(
-                cu_seqlens.shape[0], dtype=torch.int32, device=hidden_states.device
-            )
             slice_past_index = None

         # Get rotary cos and sin for this forward
@@ -726,6 +723,7 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
         input_ids,
         position_ids,
         cu_seqlens,
+        cu_seqlens_q,
         max_s,
         past_key_values: Optional[torch.Tensor] = None,
         pre_allocate_past_size: Optional[int] = None,
@@ -734,6 +732,7 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
             input_ids,
             position_ids,
             cu_seqlens,
+            cu_seqlens_q,
             max_s,
             past_key_values,
             pre_allocate_past_size,
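The signature changes above stop each model from rebuilding the decode-time query boundaries inside every forward and instead take them as the new cu_seqlens_q argument, built once per step by the batch. A minimal illustrative sketch of the bookkeeping these kernels rely on, with made-up lengths (not part of the patch):

# Illustrative sketch: varlen flash attention indexes a flattened batch through
# cumulative sequence lengths. During decode each request contributes exactly one
# query token, so the query boundaries are simply arange(batch_size + 1).
import torch

seq_lens = [3, 5, 2]                                               # hypothetical request lengths
cu_seqlens = torch.tensor([0, 3, 8, 10], dtype=torch.int32)        # key/value boundaries
cu_seqlens_q = torch.arange(len(seq_lens) + 1, dtype=torch.int32)  # one query per request

assert cu_seqlens[-1].item() == sum(seq_lens)
assert cu_seqlens_q.tolist() == [0, 1, 2, 3]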
diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
index 71182f8d..20ad8385 100644
--- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
@@ -484,6 +484,7 @@ class FlashSantacoderModel(nn.Module):
         input_ids,
         position_ids,
         cu_seqlens,
+        cu_seqlens_q,
         max_s,
         past_key_values: Optional[torch.Tensor] = None,
         pre_allocate_past_size: Optional[int] = None,
@@ -507,15 +508,11 @@ class FlashSantacoderModel(nn.Module):
                 )
             )
             layer_past_present_indices = None
-            cu_seqlens_q = None
             slice_past_index = len(hidden_states)
         # Decode
         else:
             # Create indices from cumulative sequence lengths
             layer_past_present_indices = cu_seqlens[1:] - 1
-            cu_seqlens_q = torch.arange(
-                cu_seqlens.shape[0], dtype=torch.int32, device=hidden_states.device
-            )
             slice_past_index = None

         residual = None
@@ -566,6 +563,7 @@ class FlashSantacoderForCausalLM(nn.Module):
         input_ids,
         position_ids,
         cu_seqlens,
+        cu_seqlens_q,
         max_s,
         past_key_values: Optional[torch.Tensor] = None,
         pre_allocate_past_size: Optional[int] = None,
@@ -574,6 +572,7 @@ class FlashSantacoderForCausalLM(nn.Module):
             input_ids,
             position_ids,
             cu_seqlens,
+            cu_seqlens_q,
             max_s,
             past_key_values,
             pre_allocate_past_size,
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 413866d1..b51a3dc6 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -1,6 +1,8 @@
 import torch
 import torch.distributed

+import numpy as np
+
 from torch.nn import functional as F

 from dataclasses import dataclass
@@ -33,12 +35,16 @@ class FlashCausalLMBatch(Batch):
     requests_idx_mapping: Dict[int, int]

     # Decoder values
-    input_ids: List[torch.Tensor]
-    position_ids: List[torch.Tensor]
+    input_ids: torch.Tensor
+    position_ids: torch.Tensor
+
     # cumulative sequence lengths
-    cu_seqlens: List[int]
+    cu_seqlens: torch.Tensor
+    # cumulative query sequence lengths, only used in decode
+    cu_seqlens_q: Optional[torch.Tensor]
+    # past key values, only used in decode
+    past_key_values: Optional[torch.Tensor]
     max_seqlen: int
-    past_key_values: Optional[Union[torch.Tensor, List[torch.Tensor]]]

     # All tokens
     all_input_ids: List[List[int]]
@@ -53,9 +59,6 @@ class FlashCausalLMBatch(Batch):
     next_token_choosers: List[NextTokenChooser]
     stopping_criterias: List[StoppingCriteria]

-    # Constant shared tensor, ref here just so that it's accessible in concatentate()
-    past_pad: Optional[torch.Tensor]
-
     # Maximum number of tokens this batch will grow to
     max_tokens: int
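With the new dataclass fields, a batch is one flat token dimension plus boundary tensors: request i owns the slice cu_seqlens[i]:cu_seqlens[i + 1] of every per-token tensor, so per-request tensor lists and the shared padding tensor are no longer needed. A short illustrative sketch with hypothetical lengths (not part of the patch):

# Illustrative sketch: the flattened layout the new fields describe.
import torch

lengths = [4, 2, 3]
cu_seqlens = torch.tensor([0, 4, 6, 9], dtype=torch.int32)
input_ids = torch.arange(9)                      # stand-in token ids, flat layout

for i, length in enumerate(lengths):
    start, end = int(cu_seqlens[i]), int(cu_seqlens[i + 1])
    assert input_ids[start:end].numel() == length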
@@ -74,7 +77,6 @@
         tokenizer: PreTrainedTokenizerBase,
         device: torch.device,
     ) -> "FlashCausalLMBatch":
-        input_ids = []
         position_ids = []
         cu_seqlens = [0]
         max_seqlen = 0
@@ -83,7 +85,6 @@
         offsets = []
         token_offsets = []
         all_input_ids = []
-        all_input_ids_tensor = []

         requests_idx_mapping = {}
         next_token_choosers = []
@@ -109,15 +110,11 @@
             offsets.append(None)
             token_offsets.append(None)

+            all_input_ids.append(tokenized_input)

-            tokenized_input = torch.tensor(tokenized_input, device=device)
-            input_ids.append(tokenized_input)
-
             # Position ids
-            position_ids.append(
-                torch.arange(0, input_length, dtype=torch.int32, device=device)
-            )
+            position_ids.append(np.arange(0, input_length))

             # Add cumulative lengths of all previous inputs
             cu_seqlens.append(cumulative_length + input_length)
@@ -130,14 +127,19 @@
             max_new_tokens = stopping_criteria.max_new_tokens
             stopping_criterias.append(stopping_criteria)

-            all_input_ids_tensor.append(
-                F.pad(tokenized_input, (0, stopping_criteria.max_new_tokens))
-            )
-
             # Update
             cumulative_length += input_length
             max_tokens += input_length + max_new_tokens

+        # Create tensors on device
+        input_ids = torch.tensor(
+            np.concatenate(all_input_ids), dtype=torch.int64, device=device
+        )
+        position_ids = torch.tensor(
+            np.concatenate(position_ids), dtype=torch.int32, device=device
+        )
+        cu_seqlens = torch.tensor(cu_seqlens, device=device, dtype=torch.int32)
+
         return cls(
             batch_id=pb.id,
             requests=pb.requests,
@@ -145,16 +147,16 @@
             input_ids=input_ids,
             position_ids=position_ids,
             cu_seqlens=cu_seqlens,
+            cu_seqlens_q=None,
             max_seqlen=max_seqlen,
             past_key_values=None,
             input_lengths=input_lengths,
             offsets=offsets,
             token_offsets=token_offsets,
             all_input_ids=all_input_ids,
-            all_input_ids_tensor=all_input_ids_tensor,
+            all_input_ids_tensor=[],
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
-            past_pad=None,
             max_tokens=max_tokens,
         )
@@ -174,9 +176,13 @@
         # New values after filtering
         requests_idx_mapping = {}

-        input_ids = []
-        position_ids = []
-        cu_seqlens = [0]
+        input_ids = self.input_ids.new_empty(len(requests))
+        position_ids = self.position_ids.new_empty(len(requests))
+        # Create on CPU to only move to GPU once instead of at every copy
+        cu_seqlens = torch.zeros(len(requests) + 1, dtype=torch.int32)
+        cu_seqlens_q = torch.arange(
+            0, len(requests) + 1, device=self.cu_seqlens_q.device, dtype=torch.int32
+        )
         max_seqlen = 0
         past_key_values = []
@@ -199,16 +205,18 @@
             # Get length
             request_input_length = self.input_lengths[idx]

-            input_ids.append(self.input_ids[idx])
-            position_ids.append(self.position_ids[idx])
-            cu_seqlens.append(cumulative_length + request_input_length)
-            max_seqlen = max(max_seqlen, request_input_length)
-            # True index for past
-            past_key_values.append(self.past_key_values[2 * idx])
+            # Copy tensors (GPU)
+            input_ids[i] = self.input_ids[idx]
+            position_ids[i] = self.position_ids[idx]

-            if not single_request:
-                # Add one padding
-                past_key_values.append(self.past_pad)
+            # Copy to tensor (CPU)
+            cu_seqlens[i + 1] = cumulative_length + request_input_length
+            max_seqlen = max(max_seqlen, request_input_length)
+
+            # Slice from past
+            past_key_values.append(
+                self.past_key_values[:, self.cu_seqlens[idx] : self.cu_seqlens[idx + 1]]
+            )

             all_input_ids.append(self.all_input_ids[idx])
             all_input_ids_tensor.append(self.all_input_ids_tensor[idx])
@@ -229,7 +237,7 @@
         if single_request:
             # Preallocate tensor for bs = 1 case
-            past_key_values = torch.nn.functional.pad(
+            past_key_values = F.pad(
                 past_key_values[0],
                 (
                     0,
                     0,
                     0,
                     0,
                     0,
                     0,
                     0,
                     stopping_criterias[0].max_new_tokens
                     - stopping_criterias[0].current_tokens,
                 ),
             )
+        else:
+            # Cat all past
+            past_key_values = torch.cat(past_key_values, dim=1)
+
+        # Move to GPU now that we have the whole tensor
+        cu_seqlens = cu_seqlens.to(self.cu_seqlens.device)

         return FlashCausalLMBatch(
             batch_id=self.batch_id,
-            past_pad=self.past_pad,
             requests=requests,
             requests_idx_mapping=requests_idx_mapping,
             input_ids=input_ids,
             position_ids=position_ids,
             cu_seqlens=cu_seqlens,
+            cu_seqlens_q=cu_seqlens_q,
             max_seqlen=max_seqlen,
             past_key_values=past_key_values,
             input_lengths=input_lengths,
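The filter() changes above keep requests by slicing the packed past along the flattened token dimension and concatenating the kept slices, instead of tracking per-request tensors and padding entries. An illustrative sketch with hypothetical shapes (not part of the patch):

# Illustrative sketch: filtering a packed KV cache by token-dimension slices.
import torch

cu_seqlens = [0, 4, 6, 9]                            # three requests
past_key_values = torch.randn(2, 9, 2, 1, 4)         # (layers, tokens, 2, heads, head_dim), assumed layout

keep = [0, 2]                                        # requests that are still running
kept_slices = [past_key_values[:, cu_seqlens[i] : cu_seqlens[i + 1]] for i in keep]
filtered = torch.cat(kept_slices, dim=1)
assert filtered.shape[1] == (4 - 0) + (9 - 6)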
@@ -271,9 +285,16 @@
         requests = []
         requests_idx_mapping = {}

-        input_ids = []
-        position_ids = []
+        total_batch_size = sum([len(b) for b in batches])
+
+        device = batches[0].input_ids.device
+
+        input_ids = batches[0].input_ids.new_empty(total_batch_size)
+        position_ids = batches[0].position_ids.new_empty(total_batch_size)
         cu_seqlens = [0]
+        cu_seqlens_q = torch.arange(
+            0, total_batch_size + 1, device=device, dtype=torch.int32
+        )
         max_seqlen = 0
         past_key_values = []
@@ -302,22 +323,25 @@
             for k, v in batch.requests_idx_mapping.items():
                 requests_idx_mapping[k] = v + cumulative_batch_size

-            input_ids.extend(batch.input_ids)
-            position_ids.extend(batch.position_ids)
+            start_index = cumulative_batch_size
+            end_index = cumulative_batch_size + len(batch)
+
+            # Copy tensors (GPU)
+            input_ids[start_index:end_index] = batch.input_ids
+            position_ids[start_index:end_index] = batch.position_ids
+
             # Add cumulative lengths of all previous inputs
             cu_seqlens.extend([l + cumulative_length for l in batch.cu_seqlens[1:]])
             max_seqlen = max(max_seqlen, batch.max_seqlen)

             if len(batch) != 1:
-                past_key_values.extend(batch.past_key_values)
+                past_key_values.append(batch.past_key_values)
             else:
                 # past was pre-allocated for this batch
                 # We need to slice to remove the padding
                 past_key_values.append(
                     batch.past_key_values[:, : batch.input_lengths[0]]
                 )
-                # Add one padding
-                past_key_values.append(batch.past_pad)

             all_input_ids.extend(batch.all_input_ids)
             all_input_ids_tensor.extend(batch.all_input_ids_tensor)
@@ -334,14 +358,19 @@
             cumulative_batch_size += len(batch)
             max_tokens += batch.max_tokens

+        # Cat past
+        past_key_values = torch.cat(past_key_values, dim=1)
+        # Create final tensor on GPU
+        cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32, device=device)
+
         return FlashCausalLMBatch(
             batch_id=batches[0].batch_id,
-            past_pad=batches[0].past_pad,
             requests=requests,
             requests_idx_mapping=requests_idx_mapping,
             input_ids=input_ids,
             position_ids=position_ids,
             cu_seqlens=cu_seqlens,
+            cu_seqlens_q=cu_seqlens_q,
             max_seqlen=max_seqlen,
             past_key_values=past_key_values,
             input_lengths=input_lengths,
@@ -367,10 +396,9 @@
         quantize: bool = False,
         decode_buffer: int = 3,
     ):
-        self.past_pad = None
         if torch.cuda.is_available():
             device = torch.device("cuda")
-            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+            dtype = torch.float16
         else:
             raise NotImplementedError("FlashCausalLM is only available on GPU")
@@ -410,6 +438,7 @@
         input_ids: torch.Tensor,
         position_ids: torch.Tensor,
         cu_seqlens: torch.Tensor,
+        cu_seqlens_q: Optional[torch.Tensor],
         max_s: int,
         past_key_values: Optional = None,
         pre_allocate_past_size: Optional[int] = None,
@@ -419,6 +448,7 @@
             input_ids=input_ids,
             position_ids=position_ids,
             cu_seqlens=cu_seqlens,
+            cu_seqlens_q=cu_seqlens_q,
             max_s=max_s,
             past_key_values=past_key_values,
             pre_allocate_past_size=pre_allocate_past_size,
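concatenate() now preallocates the merged tensors with new_empty and copies each batch into its slice, rather than extending Python lists of per-request tensors. A small illustrative sketch with hypothetical values (not part of the patch):

# Illustrative sketch: merging batches by slice assignment into one preallocated tensor.
import torch

batch_a = torch.tensor([11, 12, 13])
batch_b = torch.tensor([21, 22])
total_batch_size = batch_a.numel() + batch_b.numel()

input_ids = batch_a.new_empty(total_batch_size)
input_ids[: batch_a.numel()] = batch_a
input_ids[batch_a.numel() :] = batch_b
assert input_ids.tolist() == [11, 12, 13, 21, 22]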
@@ -428,22 +458,9 @@
     def generate_token(
         self, batch: FlashCausalLMBatch
     ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
-        # Shortcut when batch_size == 1
-        if len(batch) == 1:
-            input_ids = batch.input_ids[0].view(-1)
-            # No need to slice as flash attention will take care of it with cu_seqlens
-            past_key_values = batch.past_key_values
-        else:
-            # Concatenate tensors
-            input_ids = torch.cat(batch.input_ids).view(-1)
-            past_key_values = (
-                torch.cat(batch.past_key_values, dim=1)
-                if batch.past_key_values is not None
-                else None
-            )
+        prefill = batch.past_key_values is None

-        # if prefill and bs == 1
-        if past_key_values is None and len(batch) == 1:
+        if prefill and len(batch) == 1:
             # Ask to pre-allocate kv to its max size
             # == number of tokens + max_new_tokens
             pre_allocate_past_size = (
@@ -452,42 +469,74 @@
         else:
             pre_allocate_past_size = None

-        # Concatenate when prefill, torch.tensor when decode
-        position_ids = (
-            torch.tensor(batch.position_ids, device=self.device)
-            if batch.past_key_values is not None
-            else torch.cat(batch.position_ids)
-        )
-        cu_seqlens = torch.tensor(
-            batch.cu_seqlens, device=self.device, dtype=torch.int32
-        )
-
         out, present = self.forward(
-            input_ids,
-            position_ids,
-            cu_seqlens,
+            batch.input_ids,
+            batch.position_ids,
+            batch.cu_seqlens,
+            batch.cu_seqlens_q,
             batch.max_seqlen,
-            past_key_values,
+            batch.past_key_values,
             pre_allocate_past_size,
         )

-        # Initialize past_key_values in prefill
-        if batch.past_key_values is None:
-            # Initialize past padding tensor
-            if self.past_pad is None:
-                self.past_pad = present.new_zeros(
-                    present.shape[0], 1, *present.shape[2:]
+        if prefill:
+            if len(batch) > 1:
+                # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
+                # When batch == 1, we will just use the batch.input_ids values directly
+                prefill_tokens_indices = batch.input_ids.new_zeros(len(batch.input_ids))
+
+            # Create batch.cu_seqlens_q for decode
+            batch.cu_seqlens_q = torch.arange(
+                0, len(batch) + 1, device=self.device, dtype=torch.int32
+            )
+            next_input_ids = batch.input_ids.new_empty(len(batch))
+            next_position_ids = batch.position_ids.new_empty(len(batch))
+        else:
+            prefill_logprobs = None
+            next_input_ids = batch.input_ids
+            next_position_ids = batch.position_ids
+
+        next_token_logprobs = out.new_empty(len(batch))
+
+        # Prepare past for next decode
+        if len(batch) > 1:
+            # Used to slice next batch past
+            past_indices = torch.empty(
+                present.shape[1], dtype=torch.int64, device=self.device
+            )
+            batch.past_key_values = present.new_empty(
+                (
+                    present.shape[0],
+                    present.shape[1] + len(batch.requests),
+                    *present.shape[2:],
                 )
-            # Set in batch in case it needs to be used later in concatenate()
-            batch.past_pad = self.past_pad
-            if len(batch) == 1:
-                # present is already pre-padded
-                batch.past_key_values = present
-            else:
-                # Add padding after each sequence
-                # This will have the correct shape after the final past_key_values concatenation before the model
-                # forward
-                batch.past_key_values = [None, self.past_pad] * len(batch)
+            )
+
+            # It is actually faster to do a whole other for loop here as the copy from present to past is fairly slow
+            # and will run asynchronously while we do the next for loop
+            cumulative_length = 0
+            for i, input_length in enumerate(batch.input_lengths):
+                # Indexing metadata
+                start_index = cumulative_length
+                end_index = cumulative_length + input_length
+
+                # Indices to copy present at the correct place in past_key_values
+                torch.arange(
+                    start_index + i,
+                    end_index + i,
+                    dtype=torch.int64,
+                    device=self.device,
+                    out=past_indices[start_index:end_index],
+                )
+                cumulative_length += input_length
+
+            # Copy from present to past_key_values
+            batch.past_key_values[:, past_indices] = present
+
+        # Initialize past_key_values in prefill for len(batch) == 1
+        elif prefill:
+            # present is already pre-padded
+            batch.past_key_values = present

         # Cumulative length
         cumulative_length = 0
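The prefill branch above scatters present into a past buffer that reserves one free slot per request for the token generated at the next step; the indices for request i are simply its packed positions shifted right by i. An illustrative sketch with hypothetical sizes (not part of the patch):

# Illustrative sketch: building the scatter indices for the per-request spare slots.
import torch

input_lengths = [3, 2]
present_len = sum(input_lengths)
present = torch.arange(present_len, dtype=torch.float32).view(1, present_len, 1)
past = torch.zeros(1, present_len + len(input_lengths), 1)   # one spare slot per request

past_indices = torch.empty(present_len, dtype=torch.int64)
cumulative_length = 0
for i, input_length in enumerate(input_lengths):
    start, end = cumulative_length, cumulative_length + input_length
    # request i is shifted right by the i slots already reserved before it
    torch.arange(start + i, end + i, dtype=torch.int64, out=past_indices[start:end])
    cumulative_length += input_length

past[:, past_indices] = present
assert past_indices.tolist() == [0, 1, 2, 4, 5]   # slots 3 and 6 stay free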
@@ -496,6 +545,102 @@ class FlashCausalLM(Model):
         generations: List[Generation] = []
         stopped = True

+        # Zipped iterator
+        iterator = zip(
+            batch.input_lengths,
+            batch.next_token_choosers,
+            batch.stopping_criterias,
+            batch.all_input_ids,
+        )
+
+        # We do two for loops as the first one can run completely asynchronously from the GPU while for the second
+        # one, we need to first do a GPU <-> CPU sync
+        # It is faster if we delay this sync for the maximum amount of time
+
+        # For each member of the batch
+        for i, (
+            input_length,
+            next_token_chooser,
+            stopping_criteria,
+            all_input_ids,
+        ) in enumerate(iterator):
+            # Indexing metadata
+            start_index = cumulative_length
+            end_index = cumulative_length + input_length
+
+            if prefill:
+                # Prefill mode
+                # out is of shape [cumulative_sequence_lengths, vocab_size]
+                # only take last token logit
+                logits = out[end_index - 1 : end_index]
+
+                # Create all_input_ids_tensor that will be used by token warpers (for example, RepetitionPenalty)
+                all_input_ids_tensor = batch.input_ids.new_empty(
+                    input_length + stopping_criteria.max_new_tokens
+                )
+                # Copy from batch.input_ids to all_input_ids_tensor
+                all_input_ids_tensor[:input_length] = batch.input_ids[
+                    start_index:end_index
+                ]
+                batch.all_input_ids_tensor.append(all_input_ids_tensor)
+
+                # Initialize position_ids
+                # In decode, we do not need this as we can just increment position ids
+                next_position_ids[i] = batch.position_ids[end_index - 1]
+
+                # Used to gather prefill logprobs
+                # Copy batch.input_ids to prefill_token_indices
+                if len(batch) > 1:
+                    prefill_tokens_indices[
+                        start_index : end_index - 1
+                    ] = batch.input_ids[start_index + 1 : end_index]
+                else:
+                    # Set prefill_tokens_indices to the correct slice
+                    prefill_tokens_indices = batch.input_ids[
+                        start_index + 1 : end_index
+                    ]
+            else:
+                # Decode mode
+                # out is of shape [batch_size, vocab_size]
+                logits = out[i].view(1, -1)
+
+                all_input_ids_tensor = batch.all_input_ids_tensor[i]
+
+            # Select next token
+            next_token_id, logprob = next_token_chooser(
+                all_input_ids_tensor[None, :input_length], logits
+            )
+
+            # Add to all_input_ids_tensor
+            next_token_id_squeezed = next_token_id.view(1)
+            all_input_ids_tensor[input_length] = next_token_id_squeezed
+
+            # Set values
+            next_input_ids[i] = next_token_id_squeezed
+            next_token_logprobs[i] = logprob[-1, next_token_id].view(1)
+
+            cumulative_length += input_length
+
+        # Set values in batch
+        batch.input_ids = next_input_ids
+        batch.position_ids = next_position_ids + 1
+        batch.cu_seqlens = batch.cu_seqlens + batch.cu_seqlens_q
+
+        if prefill:
+            # Get prefill logprobs
+            prefill_logprobs_tensor = torch.log_softmax(out, -1)
+            prefill_logprobs = torch.gather(
+                prefill_logprobs_tensor, 1, prefill_tokens_indices.view(-1, 1)
+            )
+            # GPU <-> CPU sync
+            prefill_logprobs = prefill_logprobs.view(-1).tolist()
+
+        # GPU <-> CPU sync
+        next_token_logprobs = next_token_logprobs.tolist()
+        next_token_ids = batch.input_ids.tolist()
+
+        cumulative_length = 0
+
         # Zipped iterator
         iterator = zip(
             batch.requests,
@@ -506,6 +651,8 @@
             batch.stopping_criterias,
             batch.all_input_ids,
             batch.all_input_ids_tensor,
+            next_token_ids,
+            next_token_logprobs,
         )

         # For each member of the batch
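The batched bookkeeping above includes the decode-time update batch.cu_seqlens + batch.cu_seqlens_q: every sequence grows by one token, so boundary i must advance by exactly i. A tiny illustrative check with hypothetical lengths (not part of the patch):

# Illustrative sketch: updating cumulative boundaries after each request gains one token.
import torch

cu_seqlens = torch.tensor([0, 3, 8, 10], dtype=torch.int32)   # lengths 3, 5, 2
cu_seqlens_q = torch.arange(4, dtype=torch.int32)             # [0, 1, 2, 3]

updated = cu_seqlens + cu_seqlens_q
assert updated.tolist() == [0, 4, 10, 13]                     # lengths 4, 6, 3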
@@ -518,34 +665,16 @@
             stopping_criteria,
             all_input_ids,
             all_input_ids_tensor,
+            next_token_id,
+            next_token_logprob,
         ) in enumerate(iterator):
-            # Indexing metadata
             start_index = cumulative_length
             end_index = cumulative_length + input_length

-            prefill = stopping_criteria.current_tokens == 0
-            if prefill:
-                # Prefill mode
-                # out is of shape [cumulative_sequence_lengths, vocab_size]
-                logits = out[start_index:end_index]
-            else:
-                # Decode mode
-                # out is of shape [batch_size, vocab_size]
-                logits = out[i].unsqueeze(0)
-
-            # Select next token
-            next_token_id, logprobs = next_token_chooser(
-                all_input_ids_tensor[None, :input_length], logits
-            )
-            next_token_id_squeezed = next_token_id.squeeze()
-            next_token_id_item = next_token_id_squeezed.item()
-
             # Append next token to all tokens
-            all_input_ids.append(next_token_id_item)
-            all_input_ids_tensor[input_length] = next_token_id_item
+            all_input_ids.append(next_token_id)

             # Generated token
-            next_token_logprob = logprobs[-1, next_token_id_item]
             next_token_text, offset, token_offset = self.decode_token(
                 all_input_ids,
                 offset,
@@ -554,7 +683,7 @@

             # Evaluate stopping criteria
             stop, reason = stopping_criteria(
-                next_token_id_item,
+                next_token_id,
                 next_token_text,
             )

@@ -579,9 +708,9 @@
             # Prefill
             if prefill:
                 # Remove generated token to only have prefill and add nan for first prompt token
-                prefill_logprobs = [float("nan")] + logprobs.gather(
-                    1, all_input_ids_tensor[1:input_length].unsqueeze(1)
-                ).squeeze(1)[:-1].tolist()
+                request_prefill_logprobs = [float("nan")] + prefill_logprobs[
+                    start_index : end_index - 1
+                ]
                 prefill_token_ids = all_input_ids[:-1]
                 prefill_texts = self.tokenizer.batch_decode(
                     prefill_token_ids,
@@ -589,7 +718,7 @@
                     skip_special_tokens=False,
                 )
                 prefill_tokens = PrefillTokens(
-                    prefill_token_ids, prefill_logprobs, prefill_texts
+                    prefill_token_ids, request_prefill_logprobs, prefill_texts
                 )
             else:
                 prefill_tokens = None
@@ -597,31 +726,23 @@
             generation = Generation(
                 request.id,
                 prefill_tokens,
-                next_token_id_item,
+                next_token_id,
                 next_token_logprob,
                 next_token_text,
-                next_token_id_item in self.all_special_ids,
+                next_token_id in self.all_special_ids,
                 generated_text,
             )

             generations.append(generation)
-            cumulative_length += input_length
             new_input_length = input_length + 1

             # Update values
-            batch.input_ids[i] = next_token_id
-            batch.position_ids[i] = input_length
             batch.input_lengths[i] = new_input_length
             batch.offsets[i] = offset
             batch.token_offsets[i] = token_offset
             batch.all_input_ids[i] = all_input_ids
-            batch.all_input_ids_tensor[i] = all_input_ids_tensor
-            batch.max_seqlen = max(batch.max_seqlen, new_input_length)
-            if len(batch) != 1:
-                # Add each sequence before its padding
-                batch.past_key_values[i * 2] = present[:, start_index:end_index]
-            # Cumulative sum
-            batch.cu_seqlens[(i + 1)] = batch.cu_seqlens[i] + new_input_length
+            batch.max_seqlen = batch.max_seqlen + 1
+            cumulative_length += input_length

         # No need to return a batch if we know that all requests stopped
         return generations, batch if not stopped else None
diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py
index 105ff519..e4426771 100644
--- a/server/text_generation_server/models/flash_llama.py
+++ b/server/text_generation_server/models/flash_llama.py
@@ -32,7 +32,7 @@ class FlashLlama(FlashCausalLM):
         self.past_pad = None
         if torch.cuda.is_available():
             device = torch.device("cuda")
-            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+            dtype = torch.float16
         else:
             raise NotImplementedError("FlashLlama is only available on GPU")

@@ -161,7 +161,7 @@ class FlashLlamaSharded(FlashLlama):
         self.master = self.rank == 0
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{self.rank}")
-            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+            dtype = torch.float16
         else:
             raise NotImplementedError("FlashLlama is only available on GPU")
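In the reporting loop above, each request's prompt logprobs become a plain slice of the flat prefill_logprobs list computed earlier, with NaN standing in for the first prompt token, which has no prefix to condition on. An illustrative sketch with hypothetical values (not part of the patch):

# Illustrative sketch: slicing one request's prompt logprobs out of the flat list.
prefill_logprobs = [-0.5, -1.2, -0.1, -2.0, -0.3]    # flat, covers the whole batch
start_index, end_index = 0, 3                         # request 0 owns positions 0..2

request_prefill_logprobs = [float("nan")] + prefill_logprobs[start_index : end_index - 1]
assert len(request_prefill_logprobs) == end_index - start_index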
diff --git a/server/text_generation_server/models/flash_neox.py b/server/text_generation_server/models/flash_neox.py
index fc769583..f439e812 100644
--- a/server/text_generation_server/models/flash_neox.py
+++ b/server/text_generation_server/models/flash_neox.py
@@ -38,7 +38,7 @@ class FlashNeoXSharded(FlashNeoX):
         self.master = self.rank == 0
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{self.rank}")
-            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+            dtype = torch.float16
         else:
             raise NotImplementedError("FlashNeoX is only available on GPU")

diff --git a/server/text_generation_server/models/flash_santacoder.py b/server/text_generation_server/models/flash_santacoder.py
index 333180e8..f0825ab9 100644
--- a/server/text_generation_server/models/flash_santacoder.py
+++ b/server/text_generation_server/models/flash_santacoder.py
@@ -31,7 +31,7 @@ class FlashSantacoder(FlashCausalLM):
         self.past_pad = None
         if torch.cuda.is_available():
             device = torch.device("cuda")
-            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+            dtype = torch.float16
         else:
             raise NotImplementedError("FlashSantacoder is only available on GPU")

@@ -178,7 +178,7 @@ class FlashSantacoderSharded(FlashSantacoder):
         self.master = self.rank == 0
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{self.rank}")
-            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+            dtype = torch.float16
         else:
             raise NotImplementedError("FlashSantacoderSharded is only available on GPU")