diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 33fe30a8..8933f13e 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -149,11 +149,26 @@ class FlashCausalLMBatch(Batch):
     max_seqlen: int
 
-    # Prefill metadata tensors to efficiently compute logprobs
+    # Prefill metadata tensors
     prefill_head_indices: Optional[torch.Tensor]
-    prefill_next_token_indices: Optional[torch.tensor]
+    prefill_next_token_indices: Optional[torch.Tensor]
     prefill_cu_outlens: Optional[List[int]]
 
+    # Whether at least one request is prefilling/chunking
+    # == any(prefilling_mask)
+    prefilling: bool
+    # For each request, whether they are still prefilling/chunking
+    prefilling_mask: List[bool]
+    # For each request, whether the model output should be used or discarded
+    # If we are chunking, we don't care about the output as it might be different
+    # from the token in the prompt
+    use_output_token: List[bool]
+
+    # If the request is decoding, `next_chunk_length = 1`
+    # `None if not batch.prefilling`
+    next_chunk_lengths: Optional[List[int]]
+    next_chunk_lengths_tensor: Optional[torch.Tensor]
+
     # Prefixes
     prefix_ids: List[List[int]]
 
@@ -232,11 +247,14 @@ class FlashCausalLMBatch(Batch):
         prefix_ids = []
         requests_idx_mapping = {}
 
+        chunking = False
         all_prefill_logprobs = True
         no_prefill_logprobs = True
         prefill_head_indices = []
         prefill_next_token_indices = []
         prefill_cu_outlens = [0]
+        next_chunk_lengths = []
+        use_output_token = []
 
         next_token_chooser_parameters = []
         stopping_criterias = []
@@ -276,6 +294,7 @@ class FlashCausalLMBatch(Batch):
                 assert prefix_len > 0
                 prefix_len -= 1
 
+            # Commented as it's costly.
            # log_master(logger.debug, "Tokenized input ids {tokenized_input}")
 
             prefix_ids.append(tokenized_input[:prefix_len])
@@ -284,9 +303,18 @@ class FlashCausalLMBatch(Batch):
             input_length = len(tokenized_input)
             input_lengths.append(input_length)
 
+            if True:
+                # This request only requires one prefill and no chunking
+                use_output_token.append(True)
+                next_chunk_lengths.append(1)
+            else:
+                chunking = True
+                raise NotImplementedError
+
             prefix_offsets.append(input_length - 5)
             read_offsets.append(input_length)
 
+            # FIXME: use all input tokens not just postfix ones
             all_input_ids.append(tokenized_input)
 
             # Position ids
@@ -357,6 +385,7 @@ class FlashCausalLMBatch(Batch):
 
             # Create tensor to slice into the kv tensor in prefill
             if sliding_window is not None:
+                raise NotImplementedError
                 request_prefill_cache_indices = torch.arange(
                     cumulative_length + max(0, input_length - sliding_window),
                     cumulative_length + input_length,
@@ -368,6 +397,7 @@ class FlashCausalLMBatch(Batch):
             no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs
 
             if r.prefill_logprobs:
+                raise NotImplementedError
                 prefill_head_indices.append(request_position_ids + cumulative_length)
                 prefill_next_token_indices.append(
                     prefill_out_cumulative_length + input_length - 1
@@ -445,6 +475,12 @@ class FlashCausalLMBatch(Batch):
             adapter_segments, dtype=torch.int32, device=device
         )
 
+        if chunking:
+            next_chunk_lengths_tensor = torch.tensor(next_chunk_lengths, dtype=torch.int64, device=device)
+        else:
+            next_chunk_lengths = None
+            next_chunk_lengths_tensor = None
+
         if all_prefill_logprobs:
             prefill_head_indices = None
             prefill_next_token_indices = cu_seqlen_prefill[1:] - 1
@@ -491,6 +527,11 @@ class FlashCausalLMBatch(Batch):
             prefill_head_indices=prefill_head_indices,
             prefill_next_token_indices=prefill_next_token_indices,
             prefill_cu_outlens=prefill_cu_outlens,
+            prefilling=True,
+            prefilling_mask=[True] * len(pb.requests),
+            use_output_token=use_output_token,
+            next_chunk_lengths=next_chunk_lengths,
+            next_chunk_lengths_tensor=next_chunk_lengths_tensor,
             input_lengths=input_lengths,
             input_lengths_tensor=input_lengths_tensor,
             prefix_offsets=prefix_offsets,
@@ -1426,7 +1467,7 @@ class FlashCausalLM(Model):
             max_s = batch.max_seqlen
             lm_head_indices = batch.prefill_head_indices
 
-        if cu_seqlen_prefill is None and self.max_past() is not None:
+        if not batch.prefilling and self.max_past() is not None:
             # In decode, not prefill, we're actually overwriting the KV-cache
             # in a circular buffer mode.
             # This makes sure the max_s for the decode pass is correct.
@@ -1440,7 +1481,7 @@ class FlashCausalLM(Model):
         else:
             cuda_graph = None
 
-        if cu_seqlen_prefill is not None or cuda_graph is None:
+        if batch.prefilling or cuda_graph is None:
             if ATTENTION == "flashinfer":
                 block_tables = block_tables_to_ragged(
                     block_tables=block_tables,
@@ -1475,6 +1516,7 @@ class FlashCausalLM(Model):
                 adapter_data=adapter_data,
             )
             if batch.prefill_cache_indices is not None:
+                raise NotImplementedError
                 batch.prefill_cache_indices = None
             return logits, speculative_logits
 
@@ -1528,7 +1570,6 @@ class FlashCausalLM(Model):
         self, batch: FlashCausalLMBatch
     ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch], Tuple[int, int]]:
         start = time.time_ns()
-        prefill = batch.cu_seqlen_prefill is not None
         prefill_logprobs = batch.prefill_next_token_indices is not None
 
         # Update adapter indices for speculative tokens (if present)
@@ -1554,13 +1595,13 @@ class FlashCausalLM(Model):
         adapter_data = AdapterBatchData.from_meta(
             adapter_meta,
             self.layer_to_adapter_weights,
-            prefill,
+            batch.prefilling,
             batch.prefill_head_indices,
         )
 
         out, speculative_logits = self.forward(batch, adapter_data)
 
-        if prefill:
+        if batch.prefilling:
             next_token_logits = (
                 out[batch.prefill_next_token_indices] if prefill_logprobs else out
             )
@@ -1597,22 +1638,31 @@ class FlashCausalLM(Model):
             batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs, accepted_ids
         )
 
-        if prefill:
+        if batch.prefilling:
             if len(batch) > 1 and prefill_logprobs:
                 # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
                 # When batch == 1, we will just use the batch.input_ids values directly
                 prefill_tokens_indices = batch.input_ids.new_zeros(len(out))
 
-            next_position_ids = batch.position_ids.new_empty(len(batch))
-            batch.slot_indices = batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1]
-            # We do not need cu_seqlen_prefill anymore
-            batch.cu_seqlen_prefill = None
+            if batch.next_chunk_lengths is None:
+                # We are done prefilling after this forward
+                next_position_ids = batch.position_ids.new_empty(len(batch))
+                # [BATCH_SIZE]
+                # Last slot for each request, will be incremented later
+                batch.slot_indices = batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1]
+            else:
+                # We still have prefill chunks to go through
+                next_forward_size = sum(batch.next_chunk_lengths)
+                next_position_ids = batch.position_ids.new_empty(next_forward_size)
+                batch.slot_indices = batch.slot_indices.new_empty(next_forward_size)
+                batch.cu_seqlen_prefill[1:] = torch.cumsum(batch.next_chunk_lengths_tensor, dim=0)
         else:
             prefill_logprobs = None
             next_position_ids = batch.position_ids
 
         # Cumulative length
         cumulative_length = 0
+        cumulative_chunk_lengths = 0
 
         # Results
         generations: List[Generation] = []
@@ -1625,21 +1675,26 @@ class FlashCausalLM(Model):
        # one, we need to first do a GPU <-> CPU sync
        # It is faster if we delay this sync for the maximum amount of time
 
-        # For each member of the batch
        index = 0
+        # For each member of the batch
        for i, (input_length, all_input_ids, n_accepted_ids) in enumerate(iterator):
            # Indexing metadata
            start_index = cumulative_length
            end_index = cumulative_length + input_length
 
-            if prefill:
+            if batch.prefilling:
+                if batch.next_chunk_lengths is not None:
+                    next_chunk_length = batch.next_chunk_lengths[i]
+                else:
+                    next_chunk_length = 1
+
                # Indexing metadata
                out_start_index = batch.prefill_cu_outlens[i]
                out_end_index = batch.prefill_cu_outlens[i + 1]
                out_length = out_end_index - out_start_index
 
                # Initialize position_ids
                # In decode, we do not need this as we can just increment position ids
                next_position_ids[i] = batch.position_ids[end_index - 1]
 
                # Initialize adapter indices
@@ -1651,6 +1706,7 @@ class FlashCausalLM(Model):
                # Used to gather prefill logprobs
                # Copy batch.input_ids to prefill_token_indices
                if prefill_logprobs:
+                    raise NotImplementedError
                    if len(batch) > 1:
                        prefill_tokens_indices[out_start_index : out_end_index - 1] = (
                            batch.input_ids[start_index + 1 : start_index + out_length]
@@ -1668,14 +1724,23 @@ class FlashCausalLM(Model):
            cumulative_length += input_length
 
        # Update values
-        batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
-        batch.speculative_ids = speculative_ids
-        batch.position_ids = next_position_ids + accepted_ids
-        batch.input_lengths_tensor += accepted_ids
-        batch.slot_indices += accepted_ids
-        batch.adapter_meta.adapter_indices = next_adapter_indices
+        if batch.next_chunk_lengths is None:
+            # We are done prefilling
+            batch.prefilling = False
+            batch.next_chunk_lengths = None
+            batch.next_chunk_lengths_tensor = None
+            # We do not need cu_seqlen_prefill anymore
+            batch.cu_seqlen_prefill = None
 
-        if prefill:
+        if not batch.prefilling:
+            batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
+            batch.speculative_ids = speculative_ids
+            batch.position_ids = next_position_ids + accepted_ids
+            batch.input_lengths_tensor += accepted_ids
+            batch.slot_indices += accepted_ids
+            batch.adapter_meta.adapter_indices = next_adapter_indices
+
+        if batch.prefilling:
            # adjust segment lengths to account for all request lengths being 1 during decoding
            adapter_segments, _ = find_segments(batch.adapter_meta.adapter_indices)
            batch.adapter_meta.adapter_segments = torch.tensor(
@@ -1684,7 +1749,8 @@ class FlashCausalLM(Model):
                device=batch.adapter_meta.adapter_segments.device,
            )
 
-        if prefill and prefill_logprobs:
+        if batch.prefilling and prefill_logprobs:
+            raise NotImplementedError
            # Get prefill logprobs
            prefill_logprobs_tensor = torch.log_softmax(out, -1)
            prefill_logprobs = torch.gather(
@@ -1795,7 +1861,8 @@ class FlashCausalLM(Model):
                generated_text = None
 
            # Prefill
-            if prefill and request.prefill_logprobs:
+            if batch.prefilling and request.prefill_logprobs:
+                raise NotImplementedError
                out_start_index = batch.prefill_cu_outlens[i]
                out_end_index = batch.prefill_cu_outlens[i + 1]
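
For context, here is a minimal, self-contained sketch (not part of the patch) of the bookkeeping the chunking path above performs: given `next_chunk_lengths` (one entry per request, with 1 for requests that are already decoding), the next forward pass processes `sum(next_chunk_lengths)` tokens and `cu_seqlen_prefill` becomes the running total of those lengths, mirroring `batch.cu_seqlen_prefill[1:] = torch.cumsum(batch.next_chunk_lengths_tensor, dim=0)`. The helper name and the example values are hypothetical.

# Illustrative sketch only, not part of the patch; the helper name is hypothetical.
from typing import List, Tuple

import torch


def next_prefill_metadata(
    next_chunk_lengths: List[int], device: str = "cpu"
) -> Tuple[torch.Tensor, int]:
    """Rebuild the cumulative prefill lengths for the next forward pass.

    `next_chunk_lengths[i]` is the number of prompt tokens request `i` will
    process next (1 if the request is already decoding).
    """
    lengths = torch.tensor(next_chunk_lengths, dtype=torch.int64, device=device)
    # One leading zero, then the running total of chunk lengths, as in
    # `batch.cu_seqlen_prefill[1:] = torch.cumsum(next_chunk_lengths_tensor, dim=0)`
    cu_seqlen_prefill = torch.zeros(
        len(next_chunk_lengths) + 1, dtype=torch.int64, device=device
    )
    cu_seqlen_prefill[1:] = torch.cumsum(lengths, dim=0)
    # Total number of tokens in the next forward pass (`next_forward_size` in the
    # patch), used to size `next_position_ids` and `batch.slot_indices`.
    next_forward_size = int(cu_seqlen_prefill[-1].item())
    return cu_seqlen_prefill, next_forward_size


# Example: request 0 still has a 512-token prompt chunk left, request 1 is decoding.
# next_prefill_metadata([512, 1]) -> (tensor([0, 512, 513]), 513)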