diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 4cc285bf..8a9512c9 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -64,6 +64,8 @@ tracer = trace.get_tracer(__name__)
 # Will be set in init
 SLIDING_WINDOW: Optional[int] = None
 
+TOKEN_BUDGET = 8
+
 
 def set_sliding_window(sliding_window: int):
     global SLIDING_WINDOW
@@ -144,12 +146,14 @@ class FlashCausalLMBatch(Batch):
     # tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences
     slots: torch.Tensor
 
-    max_seqlen: int
+    max_postfix_length: int
+    max_current_length: int
 
     # Prefill metadata tensors to efficiently compute logprobs
    prefill_head_indices: Optional[torch.Tensor]
     prefill_next_token_indices: Optional[torch.tensor]
     prefill_cu_outlens: Optional[List[int]]
+    prefill_tokens: List[Optional[Tokens]]
 
     # Prefixes
     prefix_ids: List[List[int]]
@@ -257,7 +261,8 @@ class FlashCausalLMBatch(Batch):
         prefill_out_cumulative_length = 0
 
         num_blocks = 0
-        max_seqlen = 0
+        max_postfix_length = 0
+        max_current_length = 0
         max_length = 0
         max_blocks = 0
 
@@ -285,20 +290,21 @@ class FlashCausalLMBatch(Batch):
             # Commented as it's costly.
             # log_master(logger.debug, "Tokenized input ids {tokenized_input}")
             prefix_ids.append(tokenized_input[:prefix_length])
-            postfix_ids = tokenized_input[prefix_length:]
+            postfix_ids = tokenized_input[prefix_length : prefix_length + 10]
+            # postfix_ids = tokenized_input[prefix_length:]
 
             postfix_length = len(postfix_ids)
             postfix_lengths.append(postfix_length)
 
-            prefix_offsets.append(postfix_length - 5)
-            read_offsets.append(postfix_length)
+            prefix_offsets.append(prompt_length - 5)
+            read_offsets.append(prompt_length)
 
             all_postfix_ids.append(postfix_ids)
             all_input_ids.append(tokenized_input)
 
             # Position ids
             request_position_ids = torch.arange(
-                prefix_length, prompt_length, dtype=torch.int32
+                prefix_length, prefix_length + postfix_length, dtype=torch.int32
             )
             position_ids.append(request_position_ids)
 
@@ -396,11 +402,12 @@ class FlashCausalLMBatch(Batch):
             # Update
             cumulative_length += postfix_length
             cumulative_slot_tokens += slot_tokens
-            max_seqlen = max(max_seqlen, postfix_length)
             max_blocks = max(max_blocks, len(request_blocks))
+            max_postfix_length = max(max_postfix_length, postfix_length)
+            max_current_length = max(max_current_length, prefix_length + postfix_length)
             max_length = max(
                 max_length,
-                prefix_length + postfix_length + max_new_tokens + speculative_length,
+                prompt_length + max_new_tokens + speculative_length,
             )
 
         adapter_indices = torch.cat(adapter_indices_list).to(
@@ -502,10 +509,12 @@ class FlashCausalLMBatch(Batch):
             slots=slots,
             prefix_lengths=prefix_lengths,
             prefix_lengths_tensor=prefix_lengths_tensor,
-            max_seqlen=max_seqlen,
+            max_postfix_length=max_postfix_length,
+            max_current_length=max_current_length,
             prefill_head_indices=prefill_head_indices,
             prefill_next_token_indices=prefill_next_token_indices,
             prefill_cu_outlens=prefill_cu_outlens,
+            prefill_tokens=[None] * len(pb.requests),
             postfix_lengths=postfix_lengths,
             postfix_lengths_tensor=postfix_lengths_tensor,
             prompt_lengths=prompt_lengths,
@@ -565,7 +574,8 @@ class FlashCausalLMBatch(Batch):
 
         # Create on CPU to only move to GPU once instead of at every copy
         slot_indices = torch.empty(len(request_ids), dtype=torch.int64)
-        max_seqlen = 0
+        max_postfix_length = 0
+        max_current_length = 0
 
         requests = []
         start_slots = []
@@ -579,6 +589,7 @@ class FlashCausalLMBatch(Batch):
         prefix_offsets = []
         read_offsets = []
+        prefill_tokens = []
 
         stopping_criterias = []
         top_n_tokens = []
@@ -598,15 +609,18 @@ class FlashCausalLMBatch(Batch):
             # Get length
             request_postfix_length = self.postfix_lengths[idx]
-            prefix_length = self.prefix_lengths[idx]
-            max_seqlen = max(max_seqlen, request_postfix_length)
+            request_prefix_length = self.prefix_lengths[idx]
+            max_postfix_length = max(max_postfix_length, request_postfix_length)
+            max_current_length = max(
+                max_current_length, request_prefix_length + request_postfix_length
+            )
 
             all_input_ids.append(self.all_input_ids[idx])
             prefix_ids.append(self.prefix_ids[idx])
 
             prompt_lengths.append(self.prompt_lengths[idx])
             postfix_lengths.append(request_postfix_length)
-            prefix_lengths.append(prefix_length)
+            prefix_lengths.append(request_prefix_length)
             prefix_offsets.append(self.prefix_offsets[idx])
             read_offsets.append(self.read_offsets[idx])
 
@@ -614,6 +628,7 @@ class FlashCausalLMBatch(Batch):
             stopping_criterias.append(stopping_criteria)
             top_n_tokens.append(self.top_n_tokens[idx])
+            prefill_tokens.append(self.prefill_tokens[idx])
 
             ADAPTER_TO_INDEX = get_adapter_to_index()
             adapter_index = ADAPTER_TO_INDEX.get(self.requests[idx].adapter_id, 0)
@@ -683,10 +698,12 @@ class FlashCausalLMBatch(Batch):
             block_tables=block_tables,
             block_tables_tensor=block_tables_tensor,
             slots=slots,
-            max_seqlen=max_seqlen,
+            max_postfix_length=max_postfix_length,
+            max_current_length=max_current_length,
             prefill_head_indices=None,
             prefill_next_token_indices=None,
             prefill_cu_outlens=None,
+            prefill_tokens=prefill_tokens,
             prompt_lengths=prompt_lengths,
             prompt_lengths_tensor=prompt_lengths_tensor,
             postfix_lengths=postfix_lengths,
@@ -725,7 +742,8 @@ class FlashCausalLMBatch(Batch):
         total_slots = 0
         max_blocks = 0
         max_length = 0
-        max_seqlen = 0
+        max_postfix_length = 0
+        max_current_length = 0
         for b in batches:
             total_batch_size += len(b)
             total_slots += len(b.slots)
@@ -734,7 +752,8 @@ class FlashCausalLMBatch(Batch):
                 b.speculative_ids.shape[1] if b.speculative_ids is not None else 0
             )
             max_blocks = max(max_blocks, b.max_blocks)
-            max_seqlen = max(max_seqlen, b.max_seqlen)
+            max_postfix_length = max(max_postfix_length, b.max_postfix_length)
+            max_current_length = max(max_current_length, b.max_current_length)
             max_length = max(
                 max_length,
                 max(
@@ -791,6 +810,8 @@ class FlashCausalLMBatch(Batch):
         prefix_offsets = []
         read_offsets = []
 
+        prefill_tokens = []
+
         next_token_chooser_parameters = []
         fsm_grammar_states = []
         stopping_criterias = []
@@ -862,6 +883,8 @@ class FlashCausalLMBatch(Batch):
             prefix_offsets.extend(batch.prefix_offsets)
             read_offsets.extend(batch.read_offsets)
 
+            prefill_tokens.extend(batch.prefill_tokens)
+
             next_token_chooser_parameters.extend([r.parameters for r in batch.requests])
             fsm_grammar_states.extend(batch.next_token_chooser.fsm_grammar_states)
             stopping_criterias.extend(batch.stopping_criterias)
@@ -907,10 +930,12 @@ class FlashCausalLMBatch(Batch):
             prefix_lengths=prefix_lengths,
             prefix_lengths_tensor=prefix_lengths_tensor,
             slots=slots,
-            max_seqlen=max_seqlen,
+            max_postfix_length=max_postfix_length,
+            max_current_length=max_current_length,
             prefill_head_indices=None,
             prefill_next_token_indices=None,
             prefill_cu_outlens=None,
+            prefill_tokens=prefill_tokens,
             prompt_lengths=prompt_lengths,
             prompt_lengths_tensor=prompt_lengths_tensor,
             postfix_lengths=postfix_lengths,
@@ -1416,7 +1441,7 @@ class FlashCausalLM(Model):
             block_tables = batch.block_tables_tensor
             slots = batch.slots[batch.slot_indices]
             postfix_lengths = batch.postfix_lengths_tensor
-            max_s = batch.max_seqlen
+            max_s = batch.max_current_length
             lm_head_indices = batch.prefill_head_indices
 
             speculative_ids = batch.speculative_ids
@@ -1459,7 +1484,7 @@ class FlashCausalLM(Model):
             slots = batch.slots[batch.slot_indices]
             postfix_lengths = batch.postfix_lengths_tensor
             prefix_lengths_tensor = batch.prefix_lengths_tensor
-            max_s = batch.max_seqlen
+            max_s = batch.max_current_length
             lm_head_indices = batch.prefill_head_indices
 
         if cu_seqlen_prefill is None and self.max_past() is not None:
@@ -1608,15 +1633,47 @@ class FlashCausalLM(Model):
                 if prefill_logprobs
                 else speculative_logits
             )
-            next_adapter_indices = batch.adapter_meta.adapter_indices.new_empty(
-                len(batch)
-            )
-
+            if len(batch) > 1 and prefill_logprobs:
+                # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
+                # When batch == 1, we will just use the batch.input_ids values directly
+                prefill_tokens_indices = batch.input_ids.new_zeros(len(out))
         else:
+            prefill_logprobs = None
             next_token_logits = out
             next_adapter_indices = batch.adapter_meta.adapter_indices
 
-        speculate = get_speculate()
+        finished_prefilling = True
+        next_chunk_lengths = []
+        if prefill:
+            # Budget in tokens for the next batch
+            # We reserve one token per request (its next input id) so that there is
+            # always enough space for at least a single decode for the remaining requests
+            batch_budget = TOKEN_BUDGET - len(batch)
+            for prefix_length, postfix_length, prompt_length in zip(
+                batch.prefix_lengths, batch.postfix_lengths, batch.prompt_lengths
+            ):
+                remaining_prefill_tokens = max(
+                    prompt_length - prefix_length - postfix_length, 0
+                )
+                if remaining_prefill_tokens > 0:
+                    next_chunk_length = max(
+                        min(remaining_prefill_tokens, batch_budget), 1
+                    )
+                    batch_budget -= next_chunk_length
+                    finished_prefilling = False
+                else:
+                    # Since speculation will be turned off, this is always true
+                    next_chunk_length = 1
+                next_chunk_lengths.append(next_chunk_length)
+
+        # Turn off speculative decoding if some requests are still prefilling
+        # It makes the logic easier to follow
+        if prefill and not finished_prefilling:
+            speculate = 0
+            speculative_logits = None
+        else:
+            speculate = get_speculate()
+
         (
             next_input_ids,
             next_token_logprobs,
@@ -1624,7 +1681,7 @@ class FlashCausalLM(Model):
             accepted_ids,
             speculative_ids,
         ) = batch.next_token_chooser(
-            batch.all_input_ids_tensor[:, : max(batch.postfix_lengths)],
+            batch.all_input_ids_tensor[:, : batch.max_current_length],
             next_token_logits,
             speculate,
             batch.speculative_ids,
@@ -1635,18 +1692,15 @@ class FlashCausalLM(Model):
             batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs, accepted_ids
         )
 
-        if prefill:
-            if len(batch) > 1 and prefill_logprobs:
-                # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
-                # When batch == 1, we will just use the batch.input_ids values directly
-                prefill_tokens_indices = batch.input_ids.new_zeros(len(out))
-
+        # Once we are done prefilling, all the tensors that concatenated values for
+        # all the requests instantly become of shape [BATCH_SIZE]
+        if prefill and finished_prefilling:
             next_position_ids = batch.position_ids.new_empty(len(batch))
             batch.slot_indices = batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1]
-            # We do not need cu_seqlen_prefill anymore
-            batch.cu_seqlen_prefill = None
-        else:
-            prefill_logprobs = None
+            next_adapter_indices = batch.adapter_meta.adapter_indices.new_empty(
+                len(batch)
+            )
+        elif not prefill:
             next_position_ids = batch.position_ids
 
         # Cumulative length
@@ -1658,6 +1712,7 @@ class FlashCausalLM(Model):
 
             # Zipped iterator
             iterator = zip(
+                batch.prompt_lengths,
                 batch.prefix_lengths,
                 batch.postfix_lengths,
                 batch.all_input_ids,
@@ -1671,6 +1726,7 @@ class FlashCausalLM(Model):
             # For each member of the batch
             index = 0
             for i, (
+                prompt_length,
                 prefix_length,
                 postfix_length,
                 all_input_ids,
@@ -1686,15 +1742,16 @@ class FlashCausalLM(Model):
                     out_end_index = batch.prefill_cu_outlens[i + 1]
                     out_length = out_end_index - out_start_index
 
-                # Initialize position_ids
-                # In decode, we do not need this as we can just increment position ids
-                next_position_ids[i] = batch.position_ids[end_index - 1]
+                if finished_prefilling:
+                    # Initialize position_ids
+                    # In decode, we do not need this as we can just increment position ids
+                    next_position_ids[i] = batch.position_ids[end_index - 1]
 
-                # Initialize adapter indices
-                # In decode, we only have one token per row in the batch, so grab last index
-                next_adapter_indices[i] = batch.adapter_meta.adapter_indices[
-                    end_index - 1
-                ]
+                    # Initialize adapter indices
+                    # In decode, we only have one token per row in the batch, so grab last index
+                    next_adapter_indices[i] = batch.adapter_meta.adapter_indices[
+                        end_index - 1
+                    ]
 
                 # Used to gather prefill logprobs
                 # Copy batch.input_ids to prefill_token_indices
@@ -1709,30 +1766,29 @@ class FlashCausalLM(Model):
                         start_index + 1 : start_index + out_length
                     ]
 
-                for j in range(n_accepted_ids):
-                    batch.all_input_ids_tensor[i, prefix_length + postfix_length + j] = (
-                        next_input_ids[index]
-                    )
-                    index += 1
+                # `accept_tokens` is True once this request is done prefilling
+                # If it is still prefilling, the tokens we decoded should be ignored
+                accept_tokens = prefix_length + postfix_length >= prompt_length
 
-                cumulative_length += postfix_length
+                if accept_tokens:
+                    # Only save tokens if we are done prefilling for this request
+                    for j in range(n_accepted_ids):
+                        batch.all_input_ids_tensor[
+                            i, prefix_length + postfix_length + j
+                        ] = next_input_ids[index]
+                        index += 1
+
+                cumulative_length += postfix_length
 
         # Update values
-        batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
-        batch.speculative_ids = speculative_ids
-        batch.position_ids = next_position_ids + accepted_ids
-        batch.postfix_lengths_tensor += accepted_ids
-        batch.slot_indices += accepted_ids
-        batch.adapter_meta.adapter_indices = next_adapter_indices
-
-        if prefill:
-            # adjust segment lengths to account for all request lengths being 1 during decoding
-            adapter_segments, _ = find_segments(batch.adapter_meta.adapter_indices)
-            batch.adapter_meta.adapter_segments = torch.tensor(
-                adapter_segments,
-                dtype=torch.int32,
-                device=batch.adapter_meta.adapter_segments.device,
-            )
+        # These values can be updated without a GPU -> CPU sync
+        if not prefill or (prefill and finished_prefilling):
+            batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
+            batch.speculative_ids = speculative_ids
+            batch.position_ids = next_position_ids + accepted_ids
+            batch.postfix_lengths_tensor += accepted_ids
+            batch.slot_indices += accepted_ids
+            batch.adapter_meta.adapter_indices = next_adapter_indices
 
         if prefill and prefill_logprobs:
             # Get prefill logprobs
@@ -1743,15 +1799,265 @@ class FlashCausalLM(Model):
             # GPU <-> CPU sync
             prefill_logprobs = prefill_logprobs.view(-1).tolist()
 
+        # Does a GPU <-> CPU sync internally
+        if prefill and finished_prefilling:
+            # adjust segment lengths to account for all request lengths being 1 during decoding
+            adapter_segments, _ = find_segments(batch.adapter_meta.adapter_indices)
+            batch.adapter_meta.adapter_segments = torch.tensor(
+                adapter_segments,
+                dtype=torch.int32,
+                device=batch.adapter_meta.adapter_segments.device,
+            )
+
         # GPU <-> CPU sync
         next_token_logprobs = next_token_logprobs.tolist()
         next_token_ids = next_input_ids.tolist()
         accepted_ids = accepted_ids.tolist()
+
+        # Update values if we need to continue prefilling
+        # This represents the `else` case of the `Update values` if above,
+        # but since this requires the `next_token_ids` to be on CPU, it is better to do it here
+        skip_tokens = {}
+        if prefill and not finished_prefilling:
+            # Speculation must be ignored while we prefill, even with chunking;
+            # it simplifies everything
+            assert batch.speculative_ids is None
+
+            all_postfix_ids = []
+            sliding_window = get_sliding_windows()
+            position_ids = []
+            cu_seqlen_prefill = [0]
+            start_slots = []
+            slot_indices = []
+            prefill_cache_indices = []
+            all_prefill_logprobs = True
+            no_prefill_logprobs = True
+            prefill_head_indices = []
+            prefill_next_token_indices = []
+            prefill_cu_outlens = [0]
+
+            # Cumulative length
+            cumulative_length = 0
+            cumulative_slot_tokens = 0
+            prefill_out_cumulative_length = 0
+
+            slots = []
+            adapter_indices_list = []
+
+            for i, (
+                r,
+                next_token_id,
+                all_input_ids,
+                prefix_length,
+                postfix_length,
+                prompt_length,
+                next_chunk_length,
+            ) in enumerate(
+                zip(
+                    batch.requests,
+                    next_token_ids,
+                    batch.all_input_ids,
+                    batch.prefix_lengths,
+                    batch.postfix_lengths,
+                    batch.prompt_lengths,
+                    next_chunk_lengths,
+                )
+            ):
+                continue_prefilling = prefix_length + postfix_length < prompt_length
+                skip_tokens[r.id] = True
+                if continue_prefilling:
+                    # Update prefix length
+                    prefix_length = prefix_length + postfix_length
+                    batch.prefix_lengths[i] = prefix_length
+
+                    # Update postfix length
+                    postfix_length = next_chunk_length
+                    batch.max_postfix_length = max(
+                        batch.max_postfix_length, postfix_length
+                    )
+                    batch.postfix_lengths[i] = postfix_length
+
+                    # Potentially update max_current_length
+                    current_length = prefix_length + postfix_length
+                    batch.max_current_length = max(
+                        batch.max_current_length, current_length
+                    )
+
+                    # Get new prompt IDs to prefill
+                    postfix_ids = all_input_ids[
+                        prefix_length : prefix_length + postfix_length
+                    ]
+
+                    # Position ids
+                    request_position_ids = torch.arange(
+                        prefix_length, prefix_length + postfix_length, dtype=torch.int32
+                    )
+                    position_ids.append(request_position_ids)
+
+                    # Add cumulative lengths of all previous inputs
+                    cu_seqlen_prefill.append(cumulative_length + postfix_length)
+
+                    request_slots = r.slots[prefix_length:]
+                    request_slot_indices = torch.arange(
+                        cumulative_slot_tokens,
+                        cumulative_slot_tokens + postfix_length,
+                        dtype=torch.int64,
+                    )
+
+                    # Create tensor to slice into the kv tensor in prefill
+                    if sliding_window is not None:
+                        request_prefill_cache_indices = torch.arange(
+                            cumulative_length + max(0, postfix_length - sliding_window),
+                            cumulative_length + postfix_length,
+                            dtype=torch.int64,
+                        )
+
+                    all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs
+                    no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs
+
+                    if r.prefill_logprobs:
+                        prefill_head_indices.append(
+                            request_position_ids + cumulative_length
+                        )
+                        prefill_next_token_indices.append(
+                            prefill_out_cumulative_length + postfix_length - 1
+                        )
+                        prefill_cu_outlens.append(
+                            prefill_out_cumulative_length + postfix_length
+                        )
+                        prefill_out_cumulative_length += postfix_length
+                    else:
+                        prefill_head_indices.append(
+                            torch.tensor(
+                                [cumulative_length + postfix_length - 1],
+                                dtype=torch.int32,
+                            )
+                        )
+                        prefill_next_token_indices.append(prefill_out_cumulative_length)
+                        prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
+                        prefill_out_cumulative_length += 1
+
+                else:
+                    # This request is done prefilling, the new id is the one selected by the sampling method
+                    postfix_ids = [next_token_id]
+
+                    # Position_ids
+                    position_ids.append(
+                        torch.tensor(
+                            (prefix_length + postfix_length,), dtype=torch.int32
+                        )
+                    )
+
+                    # Add this request token
+                    cu_seqlen_prefill.append(cumulative_length + 1)
+
+                    request_slots = r.slots[prefix_length:]
+                    request_slot_indices = torch.tensor(
+                        (cumulative_slot_tokens + postfix_length,), dtype=torch.int64
+                    )
+
+                    # Create tensor to slice into the kv tensor in prefill
+                    if sliding_window is not None:
+                        request_prefill_cache_indices = torch.tensor(
+                            [cumulative_length], dtype=torch.int64
+                        )
+
+                    prefill_head_indices.append(
+                        torch.tensor([cumulative_length], dtype=torch.int32)
+                    )
+                    prefill_next_token_indices.append(prefill_out_cumulative_length)
+                    prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
+                    prefill_out_cumulative_length += 1
+
+                all_postfix_ids.extend(postfix_ids)
+                start_slots.append(cumulative_slot_tokens)
+                slots.extend(request_slots)
+                slot_indices.append(request_slot_indices)
+
+                if sliding_window is not None:
+                    prefill_cache_indices.append(request_prefill_cache_indices)
+
+                ADAPTER_TO_INDEX = get_adapter_to_index()
+                adapter_index = ADAPTER_TO_INDEX.get(r.adapter_id, 0)
+                adapter_indices_list.append(
+                    torch.full((postfix_length,), adapter_index)
+                )
+
+                # Update
+                cumulative_length += postfix_length
+                cumulative_slot_tokens += len(request_slots)
+
+            device = batch.input_ids.device
+            batch.start_slots = torch.tensor(start_slots, dtype=torch.int64)
+
+            if len(batch) > 1:
+                position_ids = torch.cat(position_ids)
+                slot_indices = torch.cat(slot_indices)
+                if sliding_window is not None:
+                    prefill_cache_indices = torch.cat(prefill_cache_indices)
+            else:
+                position_ids = position_ids[0]
+                slot_indices = slot_indices[0]
+                if sliding_window is not None:
+                    prefill_cache_indices = prefill_cache_indices[0]
+
+            cu_seqlen_prefill = torch.tensor(
+                cu_seqlen_prefill, device=device, dtype=torch.int32
+            )
+            batch.cu_seqlen_prefill = cu_seqlen_prefill
+            batch.position_ids = position_ids.to(device)
+            batch.slot_indices = slot_indices.to(device)
+            batch.prefill_cache_indices = (
+                prefill_cache_indices.to(device) if sliding_window is not None else None
+            )
+            batch.input_ids = torch.tensor(
+                all_postfix_ids, dtype=torch.int64, device=device
+            )
+            batch.postfix_lengths_tensor = torch.tensor(
+                batch.postfix_lengths, dtype=torch.int32, device=device
+            )
+
+            if all_prefill_logprobs:
+                prefill_head_indices = None
+                prefill_next_token_indices = cu_seqlen_prefill[1:] - 1
+            elif no_prefill_logprobs:
+                prefill_head_indices = cu_seqlen_prefill[1:] - 1
+                prefill_next_token_indices = None
+            else:
+                prefill_head_indices = torch.tensor(
+                    torch.cat(prefill_head_indices), dtype=torch.int64, device=device
+                )
+                prefill_next_token_indices = torch.tensor(
+                    prefill_next_token_indices, dtype=torch.int64, device=device
+                )
+
+            batch.prefill_head_indices = prefill_head_indices
+            batch.prefill_next_token_indices = prefill_next_token_indices
+            batch.slots = torch.tensor(slots, dtype=torch.int64, device=device)
+            batch.prefix_lengths_tensor = torch.tensor(
+                batch.prefix_lengths, dtype=torch.int32, device=device
+            )
+            adapter_indices = torch.cat(adapter_indices_list).to(
+                dtype=torch.int64, device=device
+            )
+            adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
+            adapter_segments = torch.tensor(
+                adapter_segments, dtype=torch.int32, device=device
+            )
+            batch.adapter_meta = AdapterBatchMetadata(
+                adapter_indices=adapter_indices,
+                adapter_set=batch.adapter_meta.adapter_set,
+                adapter_segments=adapter_segments,
+                segment_indices=adapter_segment_indices,
+            )
 
         start_decode = time.time_ns()
 
         # Zipped iterator
         iterator = zip(
             batch.requests,
+            batch.prompt_lengths,
+            batch.prefix_lengths,
             batch.postfix_lengths,
             batch.prefix_offsets,
             batch.read_offsets,
@@ -1770,7 +2076,9 @@ class FlashCausalLM(Model):
         index = 0
         for i, (
             request,
-            input_length,
+            prompt_length,
+            prefix_length,
+            postfix_length,
             prefix_offset,
             read_offset,
             stopping_criteria,
@@ -1783,6 +2091,61 @@ class FlashCausalLM(Model):
             top_token_ids,
             top_token_logprobs,
         ) in enumerate(iterator):
+            # Compute logprobs first: even though we might skip the token,
+            # computing its logprobs can still be required
+            # We take request.id modulo world_size, as it is robust to batch.filter,
+            # whereas the index in the batch is not, and we need this state to be stable
+            if request.id % self.world_size == self.rank:
+                # Prefill
+                if prefill and request.prefill_logprobs:
+                    out_start_index = batch.prefill_cu_outlens[i]
+                    out_end_index = batch.prefill_cu_outlens[i + 1]
+
+                    request_prefill_tokens = batch.prefill_tokens[i]
+
+                    request_prefill_logprobs = prefill_logprobs[
+                        out_start_index : out_end_index - 1
+                    ]
+                    prefill_token_ids = all_input_ids[:-1]
+
+                    if request_prefill_tokens is None:
+                        # Remove generated token to only have prefill and add nan for first prompt token
+                        request_prefill_logprobs = [float("nan")] * (
+                            len(prefix_ids) + 1
+                        ) + request_prefill_logprobs
+                        prefill_token_ids = prefix_ids + prefill_token_ids
+
+                    prefill_texts = self.tokenizer.batch_decode(
+                        prefill_token_ids,
+                        clean_up_tokenization_spaces=False,
+                        skip_special_tokens=False,
+                    )
+
+                    prefill_tokens = Tokens(
+                        prefill_token_ids,
+                        request_prefill_logprobs,
+                        prefill_texts,
+                        is_special=[],
+                    )
+                    if request_prefill_tokens is not None:
+                        prefill_tokens = request_prefill_tokens + prefill_tokens
+
+                    batch.prefill_tokens[i] = prefill_tokens
+                else:
+                    batch.prefill_tokens[i] = None
+
+            # Represent whether this request is still prefilling
+            # If it is, the tokens we decoded should be ignored
+            skip_token = skip_tokens.get(request.id, False)
+
+            if skip_token:
+                # Make sure that we do not stop: even though this request did not
+                # produce a token, it is still being processed
+                stopped = False
+                # Skip the rest of the decoding
+                # Values were updated before this for loop
+                continue
+
             # Append next token to all tokens
             next_token_texts = []
             left = 0
@@ -1823,7 +2186,7 @@ class FlashCausalLM(Model):
 
             # Shard generations
             # All generations will be appended in the rust sharded client
-            if i % self.world_size == self.rank:
+            if request.id % self.world_size == self.rank:
                 if stop:
                     # Decode generated tokens
                     output_text, _, _ = self.decode_token(
@@ -1844,31 +2207,6 @@ class FlashCausalLM(Model):
                 else:
                     generated_text = None
 
-                # Prefill
-                if prefill and request.prefill_logprobs:
-                    out_start_index = batch.prefill_cu_outlens[i]
-                    out_end_index = batch.prefill_cu_outlens[i + 1]
-
-                    # Remove generated token to only have prefill and add nan for first prompt token
-                    request_prefill_logprobs = (
-                        [float("nan")] * (len(prefix_ids) + 1)
-                    ) + prefill_logprobs[out_start_index : out_end_index - 1]
-                    prefill_token_ids = all_input_ids[:-1]
-                    prefill_texts = self.tokenizer.batch_decode(
-                        prefix_ids + prefill_token_ids,
-                        clean_up_tokenization_spaces=False,
-                        skip_special_tokens=False,
-                    )
-
-                    prefill_tokens = Tokens(
-                        prefix_ids + prefill_token_ids,
-                        request_prefill_logprobs,
-                        prefill_texts,
-                        is_special=[],
-                    )
-                else:
-                    prefill_tokens = None
-
                 if top_n_tokens > 0:
                     all_top_tokens = []
                     for top_token_ids, top_token_logprobs in zip(
@@ -1896,7 +2234,7 @@ class FlashCausalLM(Model):
 
                 generation = Generation(
                     request.id,
-                    prefill_tokens,
+                    batch.prefill_tokens[i],
                     Tokens(
                         _next_token_ids,
                        _next_token_logprobs,
@@ -1917,9 +2255,13 @@ class FlashCausalLM(Model):
             )
 
             # Update values
-            batch.postfix_lengths[i] = input_length + n_accepted_ids
-            if batch.postfix_lengths[i] > batch.max_seqlen:
-                batch.max_seqlen = batch.postfix_lengths[i]
+            current_postfix_length = postfix_length + n_accepted_ids
+            batch.max_postfix_length = max(
+                batch.max_postfix_length, current_postfix_length
+            )
+            batch.postfix_lengths[i] = current_postfix_length
+            current_length = prefix_length + current_postfix_length
+            batch.max_current_length = max(batch.max_current_length, current_length)
             batch.prefix_offsets[i] = prefix_offset
             batch.read_offsets[i] = read_offset
             batch.all_input_ids[i] = all_input_ids
@@ -1930,9 +2272,13 @@ class FlashCausalLM(Model):
             decode_ns = time.time_ns() - start_decode
             return generations, None, (forward_ns, decode_ns)
 
-        batch.prefill_cu_outlens = None
-        batch.prefill_head_indices = None
-        batch.prefill_next_token_indices = None
+        if prefill and finished_prefilling:
+            # We do not need prefill tensors anymore
+            batch.cu_seqlen_prefill = None
+            batch.prefill_cache_indices = None
+            batch.prefill_cu_outlens = None
+            batch.prefill_head_indices = None
+            batch.prefill_next_token_indices = None
 
         forward_ns = start_decode - start
         decode_ns = time.time_ns() - start_decode
diff --git a/server/text_generation_server/models/types.py b/server/text_generation_server/models/types.py
index d4e7cca7..ed9ae989 100644
--- a/server/text_generation_server/models/types.py
+++ b/server/text_generation_server/models/types.py
@@ -74,6 +74,14 @@ class Tokens:
     def __len__(self):
         return len(self.token_ids)
 
+    def __add__(self, other: "Tokens") -> "Tokens":
+        return Tokens(
+            self.token_ids + other.token_ids,
+            self.logprobs + other.logprobs,
+            self.texts + other.texts,
+            self.is_special + other.is_special,
+        )
+
 
 @dataclass
 class Generation:
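Note on the chunking schedule introduced in generate_token above: each prefill iteration hands out at most TOKEN_BUDGET tokens across the whole batch. The helper below is a minimal standalone sketch of that budgeting loop; the function name and wrapper are ours and not part of the patch, but the arithmetic mirrors the `next_chunk_lengths` computation in the diff.

    from typing import List, Tuple

    TOKEN_BUDGET = 8

    def plan_next_chunks(
        prefix_lengths: List[int],
        postfix_lengths: List[int],
        prompt_lengths: List[int],
    ) -> Tuple[List[int], bool]:
        # Reserve one slot per request so a finished request can always
        # submit its single decode token in the next forward pass.
        batch_budget = TOKEN_BUDGET - len(prompt_lengths)
        finished_prefilling = True
        next_chunk_lengths = []
        for prefix, postfix, prompt in zip(
            prefix_lengths, postfix_lengths, prompt_lengths
        ):
            remaining = max(prompt - prefix - postfix, 0)
            if remaining > 0:
                # Still prefilling: take a chunk out of the shared budget,
                # but never less than one token so every request makes progress.
                chunk = max(min(remaining, batch_budget), 1)
                batch_budget -= chunk
                finished_prefilling = False
            else:
                # Done prefilling: the single slot holds the sampled next token.
                chunk = 1
            next_chunk_lengths.append(chunk)
        return next_chunk_lengths, finished_prefilling

    # Two requests: one with 10 of 20 prompt tokens prefilled, one fully prefilled.
    print(plan_next_chunks([0, 0], [10, 3], [20, 3]))  # ([6, 1], False)

Note that the budget is spent greedily in batch order, so earlier requests finish their prompts first while later ones are kept alive with at least one token per iteration.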
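The new Tokens.__add__ exists so generate_token can accumulate prefill logprobs across chunks (`batch.prefill_tokens[i] = request_prefill_tokens + prefill_tokens`). A minimal usage sketch follows, with a trimmed copy of the dataclass and made-up illustrative values:

    from dataclasses import dataclass
    from typing import List

    @dataclass
    class Tokens:
        token_ids: List[int]
        logprobs: List[float]
        texts: List[str]
        is_special: List[bool]

        def __add__(self, other: "Tokens") -> "Tokens":
            # Field-wise concatenation: chunk order is preserved.
            return Tokens(
                self.token_ids + other.token_ids,
                self.logprobs + other.logprobs,
                self.texts + other.texts,
                self.is_special + other.is_special,
            )

    first_chunk = Tokens([101, 102], [float("nan"), -0.7], ["Hello", " world"], [False, False])
    second_chunk = Tokens([103], [-1.3], ["!"], [False])
    merged = first_chunk + second_chunk
    assert merged.token_ids == [101, 102, 103]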