diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 1e81e673..2836dcc8 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -120,39 +120,47 @@ class FlashCausalLMBatch(Batch): # Decoder values input_ids: torch.Tensor - position_ids: torch.Tensor + # Will be set by `generate_token` and reset after each prefill forward before staying set in decode + position_ids: Optional[torch.Tensor] speculative_ids: Optional[torch.Tensor] - # Flash Attention values - - # tensor of length b containing the cumulative sequence lengths of the sequences in the batch, only used in prefill - cu_seqlen_prefill: Optional[torch.Tensor] - # Prefill cache indices is used to slice into the kv tensor before caching it into the paged attention buffers - # as we only keep SLIDING_WINDOW values instead of the whole tensor - prefill_cache_indices: Optional[torch.Tensor] - - # Paged Attention values - # Set when creating the batch # CPU tensor of length b indicating the start of each sequence in slots - start_slots: torch.Tensor + # Will be set by `generate_token` and reset after each prefill forward before staying set in decode + start_slots: Optional[torch.Tensor] # tensor of indices of the currently used slots, length = \sum_{i=0}^{b} s_i in prefill, length = b in decode - slot_indices: torch.Tensor + # Will be set by `generate_token` and reset after each prefill forward before staying set in decode + slot_indices: Optional[torch.Tensor] # list of length b of list of length s_i // block_size block_tables: List[List[int]] # tensor of size [b, max_total_seqlen // block_size] holding the paged attention block tables for all sequences block_tables_tensor: torch.Tensor # tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences - slots: torch.Tensor + # Will be set by `generate_token` and reset after each prefill forward before staying set in decode + slots: Optional[torch.Tensor] max_postfix_length: int max_current_length: int + # Whether this batch contains at least one request that is prefilling + prefilling: bool + # Whether each request is prefilling + prefilling_mask: List[bool] + # Prefill metadata tensors to efficiently compute logprobs + # tensor of length b containing the cumulative sequence lengths of the sequences in the batch, only used in prefill + cu_seqlen_prefill: Optional[torch.Tensor] + # Prefill cache indices is used to slice into the kv tensor before caching it into the paged attention buffers + # as we only keep SLIDING_WINDOW values instead of the whole tensor + prefill_cache_indices: Optional[torch.Tensor] + # Will be set by `generate_token` and reset after each prefill forward prefill_head_indices: Optional[torch.Tensor] + # Will be set by `generate_token` and reset after each prefill forward prefill_next_token_indices: Optional[torch.tensor] + # Will be set by `generate_token` and reset after each prefill forward prefill_cu_outlens: Optional[List[int]] + # Will be set by `generate_token` and reset after each prefill forward prefill_tokens: List[Optional[Tokens]] # Prefixes @@ -164,12 +172,13 @@ class FlashCausalLMBatch(Batch): # Lengths of all generations present in the batch postfix_lengths: List[int] - postfix_lengths_tensor: torch.Tensor # size [b], containing the number of blocks that can be retrieved from the cache prefix_lengths: List[int] - prefix_lengths_tensor: torch.Tensor prompt_lengths: List[int] 
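The hunk above turns several `FlashCausalLMBatch` tensors into `Optional` fields that are only materialized right before a prefill forward, and introduces `prefilling` / `prefilling_mask` to track which requests are still consuming their prompt. A minimal sketch of the mask semantics, assuming (as the later `generate_token` hunk suggests) that a request keeps prefilling while `prefix_length + postfix_length < prompt_length`; the helper name and example values below are invented for illustration:

```python
from typing import List

# Hypothetical helper illustrating the semantics of `prefilling_mask`: a request
# keeps prefilling while it has not yet consumed its whole prompt, i.e. while
# prefix_length + postfix_length < prompt_length.
def compute_prefilling_mask(
    prefix_lengths: List[int],
    postfix_lengths: List[int],
    prompt_lengths: List[int],
) -> List[bool]:
    return [
        prefix + postfix < prompt
        for prefix, postfix, prompt in zip(
            prefix_lengths, postfix_lengths, prompt_lengths
        )
    ]

# `prefilling` for the whole batch is simply the OR over the per-request mask.
# Example: only the second request still has prompt tokens left to process.
mask = compute_prefilling_mask([64, 128, 0], [64, 128, 32], [128, 512, 32])
assert mask == [False, True, False]
assert any(mask)  # the batch as a whole is still prefilling
```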
- prompt_lengths_tensor: torch.Tensor + # Will be set by `generate_token` and reset after each prefill forward before staying set in decode + postfix_lengths_tensor: Optional[torch.Tensor] + prefix_lengths_tensor: Optional[torch.Tensor] + prompt_lengths_tensor: Optional[torch.Tensor] prefix_offsets: List[Optional[int]] read_offsets: List[Optional[int]] @@ -181,7 +190,8 @@ class FlashCausalLMBatch(Batch): top_n_tokens_tensor: torch.Tensor # Adapter metadata for each request - adapter_meta: AdapterBatchMetadata + # Will be set by `generate_token` and reset after each prefill forward before staying set in decode + adapter_meta: Optional[AdapterBatchMetadata] # Number of blocks in this batch num_blocks: int @@ -225,13 +235,7 @@ class FlashCausalLMBatch(Batch): dtype: torch.dtype, device: torch.device, ) -> "FlashCausalLMBatch": - sliding_window = get_sliding_windows() speculate = get_speculate() - position_ids = [] - cu_seqlen_prefill = [0] - start_slots = [] - slot_indices = [] - prefill_cache_indices = [] prefix_lengths = [] postfix_lengths = [] @@ -243,24 +247,10 @@ class FlashCausalLMBatch(Batch): prefix_ids = [] requests_idx_mapping = {} - all_prefill_logprobs = True - no_prefill_logprobs = True - prefill_head_indices = [] - prefill_next_token_indices = [] - prefill_cu_outlens = [0] - next_token_chooser_parameters = [] stopping_criterias = [] top_n_tokens = [] - adapter_indices_list = [] - adapter_set = set() - - # Cumulative length - cumulative_length = 0 - cumulative_slot_tokens = 0 - prefill_out_cumulative_length = 0 - num_blocks = 0 max_postfix_length = 0 max_current_length = 0 @@ -268,7 +258,6 @@ class FlashCausalLMBatch(Batch): max_blocks = 0 block_tables = [] - slots = [] # Parse batch for i, (r, tokenized_input) in enumerate( @@ -292,8 +281,6 @@ class FlashCausalLMBatch(Batch): # FIXME: speculate is not supported for context chunking at the moment assert speculate == 0 - # Commented as it's costly. - # log_master(logger.debug, "Tokenized input ids {tokenized_input}") prefix_ids.append(tokenized_input[:prefix_length]) postfix_ids = tokenized_input[prefix_length : postfix_length] # postfix_ids = tokenized_input[prefix_length:] @@ -307,15 +294,6 @@ class FlashCausalLMBatch(Batch): all_postfix_ids.append(postfix_ids) all_input_ids.append(tokenized_input) - # Position ids - request_position_ids = torch.arange( - prefix_length, prefix_length + postfix_length, dtype=torch.int32 - ) - position_ids.append(request_position_ids) - - # Add cumulative lengths of all previous inputs - cu_seqlen_prefill.append(cumulative_length + postfix_length) - next_token_chooser_parameters.append(r.parameters) stopping_criteria = StoppingCriteria.from_pb( @@ -325,11 +303,6 @@ class FlashCausalLMBatch(Batch): stopping_criterias.append(stopping_criteria) top_n_tokens.append(r.top_n_tokens) - ADAPTER_TO_INDEX = get_adapter_to_index() - adapter_index = ADAPTER_TO_INDEX.get(r.adapter_id, 0) - adapter_indices_list.append(torch.full((postfix_length,), adapter_index)) - adapter_set.add(adapter_index) - # Paged attention # Remove one as the first token des not have a past speculative_length = get_speculate() @@ -338,75 +311,21 @@ class FlashCausalLMBatch(Batch): # Tokens that need to be mapped to blocks. block_tokens = prompt_length + max_new_tokens - 1 + speculative_length - # Tokens that need to be mapped to slots. We don't need slots for the - # cached prefix (if present). 
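For reference, the block/slot arithmetic that stays in the constructor (blocks) and that later moves into `prepare_for_prefill` (slots) can be summarized as below. This is a standalone sketch rather than the shipped code: `BLOCK_SIZE = 16`, the helper name, and the fallback allocation are assumptions, and the real code uses the blocks and slots provided on the request when they exist.

```python
import math
from typing import List, Tuple

BLOCK_SIZE = 16  # assumed value for the example; the real constant comes from the paged-attention config

# Standalone sketch of the block/slot arithmetic: every token ever written to
# the KV cache needs a block, but only the non-cached postfix needs slots.
def blocks_and_slots(
    prompt_length: int,
    prefix_length: int,
    max_new_tokens: int,
    speculative_length: int = 0,
    first_block: int = 0,
) -> Tuple[List[int], List[int]]:
    # Remove one as the first generated token does not have a past yet.
    block_tokens = prompt_length + max_new_tokens - 1 + speculative_length
    needed_blocks = math.ceil(block_tokens / BLOCK_SIZE)
    blocks = list(range(first_block, first_block + needed_blocks))
    # One slot per position inside the allocated blocks...
    slots = [s for b in blocks for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)]
    # ...minus the cached prefix, whose KV entries already exist.
    return blocks, slots[prefix_length:]

blocks, slots = blocks_and_slots(prompt_length=20, prefix_length=0, max_new_tokens=13)
assert blocks == [0, 1] and len(slots) == 32
```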
- slot_tokens = postfix_length + max_new_tokens - 1 + speculative_length - # blocks and slots can be empty (for example in warmup) if not r.blocks: needed_blocks = math.ceil(block_tokens / BLOCK_SIZE) request_blocks = [ b for b in range(num_blocks, num_blocks + needed_blocks) ] - request_slots = [ - s - for b in request_blocks - for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE) - ] else: request_blocks = r.blocks - request_slots = r.slots[ - prefix_length: #: orig_input_length + max_new_tokens + speculative_length - ] block_tables.append(request_blocks) - slots.extend(request_slots) prefix_lengths.append(prefix_length) num_blocks += len(request_blocks) - start_slots.append(cumulative_slot_tokens) - - request_slot_indices = torch.arange( - cumulative_slot_tokens, - cumulative_slot_tokens + postfix_length, - dtype=torch.int64, - ) - slot_indices.append(request_slot_indices) - - # Create tensor to slice into the kv tensor in prefill - if sliding_window is not None: - request_prefill_cache_indices = torch.arange( - cumulative_length + max(0, postfix_length - sliding_window), - cumulative_length + postfix_length, - dtype=torch.int64, - ) - prefill_cache_indices.append(request_prefill_cache_indices) - - all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs - no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs - - if r.prefill_logprobs: - prefill_head_indices.append(request_position_ids + cumulative_length) - prefill_next_token_indices.append( - prefill_out_cumulative_length + postfix_length - 1 - ) - prefill_cu_outlens.append( - prefill_out_cumulative_length + postfix_length - ) - prefill_out_cumulative_length += postfix_length - else: - prefill_head_indices.append( - torch.tensor( - [cumulative_length + postfix_length - 1], dtype=torch.int32 - ) - ) - prefill_next_token_indices.append(prefill_out_cumulative_length) - prefill_cu_outlens.append(prefill_out_cumulative_length + 1) - prefill_out_cumulative_length += 1 # Update - cumulative_length += postfix_length - cumulative_slot_tokens += slot_tokens max_blocks = max(max_blocks, len(request_blocks)) max_postfix_length = max(max_postfix_length, postfix_length) max_current_length = max(max_current_length, prefix_length + postfix_length) @@ -415,14 +334,9 @@ class FlashCausalLMBatch(Batch): prompt_length + max_new_tokens + speculative_length, ) - adapter_indices = torch.cat(adapter_indices_list).to( - dtype=torch.int64, device=device - ) - next_token_chooser = HeterogeneousNextTokenChooser.from_pb( next_token_chooser_parameters, dtype, device, tokenizer ) - start_slots = torch.tensor(start_slots, dtype=torch.int64) # Padded all_input_ids_tensor all_input_ids_tensor = np.zeros( @@ -438,92 +352,37 @@ class FlashCausalLMBatch(Batch): if len(pb.requests) > 1: input_ids = np.concatenate(all_postfix_ids, dtype=np.int64) - position_ids = torch.cat(position_ids) - slot_indices = torch.cat(slot_indices) - if sliding_window is not None: - prefill_cache_indices = torch.cat(prefill_cache_indices) else: input_ids = all_postfix_ids[0] - position_ids = position_ids[0] - slot_indices = slot_indices[0] - if sliding_window is not None: - prefill_cache_indices = prefill_cache_indices[0] - - cu_seqlen_prefill = torch.tensor( - cu_seqlen_prefill, device=device, dtype=torch.int32 - ) - position_ids = position_ids.to(device) - slot_indices = slot_indices.to(device) - prefill_cache_indices = ( - prefill_cache_indices.to(device) if sliding_window is not None else None - ) input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device) 
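The prefill-index bookkeeping removed from `from_pb` here reappears in `prepare_for_prefill` further down. A small illustration of what those indices look like, using made-up chunk lengths; only the all-logprobs and no-logprobs shortcuts from the diff are shown, the mixed case is omitted:

```python
import itertools

import torch

# Made-up chunk lengths for three requests currently being prefilled.
postfix_lengths = [3, 5, 2]

# Exclusive prefix sum of the chunk lengths over the flattened batch.
cu_seqlen_prefill = torch.tensor(
    [0] + list(itertools.accumulate(postfix_lengths)), dtype=torch.int32
)  # tensor([0, 3, 8, 10])

# If every request asked for prefill logprobs, logits are kept for all prompt
# positions and only the position of each request's last chunk token is needed
# to pick the next token:
prefill_next_token_indices = cu_seqlen_prefill[1:] - 1  # tensor([2, 7, 9])

# If no request asked for prefill logprobs, only those last positions are run
# through the LM head in the first place:
prefill_head_indices = cu_seqlen_prefill[1:] - 1  # tensor([2, 7, 9])
```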
- postfix_lengths_tensor = torch.tensor( - postfix_lengths, dtype=torch.int32, device=device - ) - prompt_lengths_tensor = torch.tensor( - prompt_lengths, dtype=torch.int32, device=device - ) - adapter_segments, adapter_segment_indices = find_segments(adapter_indices) - adapter_segments = torch.tensor( - adapter_segments, dtype=torch.int32, device=device - ) - - if all_prefill_logprobs: - prefill_head_indices = None - prefill_next_token_indices = cu_seqlen_prefill[1:] - 1 - elif no_prefill_logprobs: - prefill_head_indices = cu_seqlen_prefill[1:] - 1 - prefill_next_token_indices = None - else: - prefill_head_indices = torch.tensor( - torch.cat(prefill_head_indices), dtype=torch.int64, device=device - ) - prefill_next_token_indices = torch.tensor( - prefill_next_token_indices, dtype=torch.int64, device=device - ) top_n_tokens_tensor = torch.tensor( top_n_tokens, device=device, dtype=torch.int64 ) - slots = torch.tensor(slots, dtype=torch.int64, device=device) - block_tables_tensor = torch.zeros( (len(block_tables), max_blocks), dtype=torch.int32, device="cpu" ) for i, request_blocks in enumerate(block_tables): block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks) block_tables_tensor = block_tables_tensor.to(device) - prefix_lengths_tensor = torch.tensor( - prefix_lengths, dtype=torch.int32, device=device - ) return cls( batch_id=pb.id, requests=pb.requests, requests_idx_mapping=requests_idx_mapping, input_ids=input_ids, - position_ids=position_ids, - cu_seqlen_prefill=cu_seqlen_prefill, - prefill_cache_indices=prefill_cache_indices, - start_slots=start_slots, - slot_indices=slot_indices, + block_tables=block_tables, block_tables_tensor=block_tables_tensor, - slots=slots, prefix_lengths=prefix_lengths, - prefix_lengths_tensor=prefix_lengths_tensor, max_postfix_length=max_postfix_length, max_current_length=max_current_length, - prefill_head_indices=prefill_head_indices, - prefill_next_token_indices=prefill_next_token_indices, - prefill_cu_outlens=prefill_cu_outlens, + prefilling=True, + prefilling_mask=[True] * len(pb.requests), prefill_tokens=[None] * len(pb.requests), postfix_lengths=postfix_lengths, - postfix_lengths_tensor=postfix_lengths_tensor, prompt_lengths=prompt_lengths, - prompt_lengths_tensor=prompt_lengths_tensor, prefix_offsets=prefix_offsets, read_offsets=read_offsets, all_input_ids=all_input_ids, @@ -535,13 +394,22 @@ class FlashCausalLMBatch(Batch): top_n_tokens_tensor=top_n_tokens_tensor, num_blocks=num_blocks, max_blocks=max_blocks, - adapter_meta=AdapterBatchMetadata( - adapter_indices=adapter_indices, - adapter_set=adapter_set, - adapter_segments=adapter_segments, - segment_indices=adapter_segment_indices, - ), speculative_ids=None, + + # These values will be set by `FlashCausalLMBatch.prepare_for_prefill` + position_ids=None, + cu_seqlen_prefill=None, + prefill_cache_indices=None, + start_slots=None, + slot_indices=None, + slots=None, + prefill_head_indices=None, + prefill_next_token_indices=None, + prefill_cu_outlens=None, + prefix_lengths_tensor=None, + postfix_lengths_tensor=None, + prompt_lengths_tensor=None, + adapter_meta=None, ) @classmethod @@ -594,6 +462,7 @@ class FlashCausalLMBatch(Batch): prefix_offsets = [] read_offsets = [] + prefilling_mask = [] prefill_tokens = [] stopping_criterias = [] @@ -604,6 +473,7 @@ class FlashCausalLMBatch(Batch): max_blocks = 0 # Cumulative length cumulative_max_length = 0 + prefilling=False for i, request_id in enumerate(request_ids): idx = self.requests_idx_mapping[request_id] @@ -612,6 +482,11 @@ class 
FlashCausalLMBatch(Batch): requests.append(self.requests[idx]) + # Prefilling + request_prefilling = self.prefilling_mask[idx] + prefilling = prefilling or request_prefilling + prefilling_mask.append(request_prefilling) + # Get length request_postfix_length = self.postfix_lengths[idx] request_prefix_length = self.prefix_lengths[idx] @@ -705,6 +580,8 @@ class FlashCausalLMBatch(Batch): slots=slots, max_postfix_length=max_postfix_length, max_current_length=max_current_length, + prefilling=prefilling, + prefilling_mask=prefilling_mask, prefill_head_indices=None, prefill_next_token_indices=None, prefill_cu_outlens=None, @@ -742,6 +619,7 @@ class FlashCausalLMBatch(Batch): requests = [] requests_idx_mapping = {} + prefilling = False num_blocks = 0 total_batch_size = 0 total_slots = 0 @@ -772,6 +650,7 @@ class FlashCausalLMBatch(Batch): ) ), ) + prefilling = prefilling or b.prefilling input_ids = batches[0].input_ids.new_empty(total_batch_size) position_ids = batches[0].position_ids.new_empty(total_batch_size) @@ -821,6 +700,7 @@ class FlashCausalLMBatch(Batch): fsm_grammar_states = [] stopping_criterias = [] top_n_tokens = [] + prefilling_mask = [] # Cumulative length cumulative_batch_size = 0 @@ -878,6 +758,7 @@ class FlashCausalLMBatch(Batch): start_slots.append(batch.start_slots + cumulative_slots) + prefilling_mask = prefilling_mask.extend(batch.prefilling_mask) block_tables.extend(batch.block_tables) prefix_lengths.extend(batch.prefix_lengths) all_input_ids.extend(batch.all_input_ids) @@ -937,6 +818,8 @@ class FlashCausalLMBatch(Batch): slots=slots, max_postfix_length=max_postfix_length, max_current_length=max_current_length, + prefilling=prefilling, + prefilling_mask=prefilling_mask, prefill_head_indices=None, prefill_next_token_indices=None, prefill_cu_outlens=None, @@ -965,6 +848,193 @@ class FlashCausalLMBatch(Batch): ), ) + def prepare_for_prefill(self): + # Prepare values if we need to continue prefilling + # Speculation must be ignored while we prefill even with chunking + # it simplifies everything + assert self.speculative_ids is None + + sliding_window = get_sliding_windows() + position_ids = [] + cu_seqlen_prefill = [0] + start_slots = [] + slot_indices = [] + prefill_cache_indices = [] + all_prefill_logprobs = True + no_prefill_logprobs = True + prefill_head_indices = [] + prefill_next_token_indices = [] + prefill_cu_outlens = [0] + + # Cumulative length + cumulative_length = 0 + cumulative_slot_tokens = 0 + prefill_out_cumulative_length = 0 + + slots = [] + adapter_indices_list = [] + adapter_set = set() + + for i, ( + r, + prefix_length, + postfix_length, + prompt_length, + request_prefilling, + blocks + ) in enumerate( + zip( + self.requests, + self.prefix_lengths, + self.postfix_lengths, + self.prompt_lengths, + self.prefilling_mask, + self.block_tables + ) + ): + next_chunk_length = postfix_length + # Position ids + request_position_ids = torch.arange( + prefix_length, prefix_length + postfix_length, dtype=torch.int32 + ) + position_ids.append(request_position_ids) + + # Add cumulative lengths of all previous inputs + cu_seqlen_prefill.append(cumulative_length + postfix_length) + + if not r.slots: + request_slots = [ + s + for b in blocks + for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE) + ] + else: + request_slots = r.slots + + request_slots = request_slots[prefix_length:] + request_slot_indices = torch.arange( + cumulative_slot_tokens, + cumulative_slot_tokens + postfix_length, + dtype=torch.int64, + ) + + # Create tensor to slice into the kv tensor in prefill + if 
sliding_window is not None: + request_prefill_cache_indices = torch.arange( + cumulative_length + max(0, postfix_length - sliding_window), + cumulative_length + postfix_length, + dtype=torch.int64, + ) + + # Prefill logprobs is ignored if the request is done prefilling + prefill_logprobs = r.prefill_logprobs and request_prefilling + + all_prefill_logprobs = all_prefill_logprobs and prefill_logprobs + no_prefill_logprobs = no_prefill_logprobs and not prefill_logprobs + + if prefill_logprobs: + prefill_head_indices.append( + request_position_ids + cumulative_length + ) + prefill_next_token_indices.append( + prefill_out_cumulative_length + postfix_length - 1 + ) + prefill_cu_outlens.append( + prefill_out_cumulative_length + postfix_length + ) + prefill_out_cumulative_length += postfix_length + else: + prefill_head_indices.append( + torch.tensor( + [cumulative_length + postfix_length - 1], + dtype=torch.int32, + ) + ) + prefill_next_token_indices.append(prefill_out_cumulative_length) + prefill_cu_outlens.append(prefill_out_cumulative_length + 1) + prefill_out_cumulative_length += 1 + + + start_slots.append(cumulative_slot_tokens) + slots.extend(request_slots) + slot_indices.append(request_slot_indices) + + if sliding_window is not None: + prefill_cache_indices.append(request_prefill_cache_indices) + + ADAPTER_TO_INDEX = get_adapter_to_index() + adapter_index = ADAPTER_TO_INDEX.get(r.adapter_id, 0) + adapter_indices_list.append( + torch.full((next_chunk_length,), adapter_index) + ) + adapter_set.add(adapter_index) + + # Update + cumulative_length += next_chunk_length + cumulative_slot_tokens += len(request_slots) + + device = self.input_ids.device + self.start_slots = torch.tensor(start_slots, dtype=torch.int64) + + if len(self) > 1: + position_ids = torch.cat(position_ids) + slot_indices = torch.cat(slot_indices) + if sliding_window is not None: + prefill_cache_indices = torch.cat(prefill_cache_indices) + else: + position_ids = position_ids[0] + slot_indices = slot_indices[0] + if sliding_window is not None: + prefill_cache_indices = prefill_cache_indices[0] + + self.prefill_cu_outlens = prefill_cu_outlens + cu_seqlen_prefill = torch.tensor( + cu_seqlen_prefill, device=device, dtype=torch.int32 + ) + self.cu_seqlen_prefill = cu_seqlen_prefill + self.position_ids = position_ids.to(device) + self.slot_indices = slot_indices.to(device) + self.prefill_cache_indices = ( + prefill_cache_indices.to(device) if sliding_window is not None else None + ) + self.postfix_lengths_tensor = torch.tensor( + self.postfix_lengths, dtype=torch.int32, device=device + ) + + if all_prefill_logprobs: + prefill_head_indices = None + prefill_next_token_indices = cu_seqlen_prefill[1:] - 1 + elif no_prefill_logprobs: + prefill_head_indices = cu_seqlen_prefill[1:] - 1 + prefill_next_token_indices = None + else: + prefill_head_indices = torch.tensor( + torch.cat(prefill_head_indices), dtype=torch.int64, device=device + ) + prefill_next_token_indices = torch.tensor( + prefill_next_token_indices, dtype=torch.int64, device=device + ) + + self.prefill_head_indices = prefill_head_indices + self.prefill_next_token_indices = prefill_next_token_indices + self.slots = torch.tensor(slots, dtype=torch.int64, device=device) + self.prefix_lengths_tensor = torch.tensor( + self.prefix_lengths, dtype=torch.int32, device=device + ) + adapter_indices = torch.cat(adapter_indices_list).to( + dtype=torch.int64, device=device + ) + adapter_segments, adapter_segment_indices = find_segments(adapter_indices) + adapter_segments = torch.tensor( + 
adapter_segments, dtype=torch.int32, device=device + ) + self.adapter_meta = AdapterBatchMetadata( + adapter_indices=adapter_indices, + adapter_set=adapter_set, + adapter_segments=adapter_segments, + segment_indices=adapter_segment_indices, + ) + def __len__(self): return len(self.requests) @@ -1596,7 +1666,10 @@ class FlashCausalLM(Model): self, batch: FlashCausalLMBatch ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch], Tuple[int, int]]: start = time.time_ns() - prefill = batch.cu_seqlen_prefill is not None + prefill = batch.prefilling + if prefill: + batch.prepare_for_prefill() + prefill_logprobs = batch.prefill_next_token_indices is not None # Update adapter indices for speculative tokens (if present) @@ -1650,6 +1723,7 @@ class FlashCausalLM(Model): finished_prefilling = True next_chunk_lengths = [] if prefill: + next_prefilling_mask = [] # Budget in tokens for the next batch # We remove next input ids to always have enough space for at least a single decode # for the remaining requests @@ -1666,11 +1740,16 @@ class FlashCausalLM(Model): ) batch_budget -= next_chunk_length finished_prefilling = False + next_prefilling_mask.append(True) else: # Since speculation will be turned off, this is always true next_chunk_length = 1 + next_prefilling_mask.append(False) next_chunk_lengths.append(next_chunk_length) + batch.prefilling = not finished_prefilling + batch.prefilling_mask = next_prefilling_mask + # Turn off speculative if some requests are still prefilling # It makes the logic easier to follow if prefill and not finished_prefilling: @@ -1708,13 +1787,6 @@ class FlashCausalLM(Model): elif not prefill: next_position_ids = batch.position_ids - # Cumulative length - cumulative_length = 0 - - # Results - generations: List[Generation] = [] - stopped = True - # Zipped iterator iterator = zip( batch.prompt_lengths, @@ -1730,6 +1802,8 @@ class FlashCausalLM(Model): # For each member of the batch index = 0 + # Cumulative length + cumulative_length = 0 for i, ( prompt_length, prefix_length, @@ -1822,242 +1896,51 @@ class FlashCausalLM(Model): # Update values if we need to continue prefilling # This represents the `else` case of the `Update values` if above # but since this require the `next_token_ids` to be on CPU, it is better to do it here - skip_tokens = {} if prefill and not finished_prefilling: # Speculation must be ignored while we prefill even with chunking # it simplifies everything assert batch.speculative_ids is None all_postfix_ids = [] - sliding_window = get_sliding_windows() - position_ids = [] - cu_seqlen_prefill = [0] - start_slots = [] - slot_indices = [] - prefill_cache_indices = [] - all_prefill_logprobs = True - no_prefill_logprobs = True - prefill_head_indices = [] - prefill_next_token_indices = [] - prefill_cu_outlens = [0] - - # Cumulative length - cumulative_length = 0 - cumulative_slot_tokens = 0 - prefill_out_cumulative_length = 0 - - slots = [] - adapter_indices_list = [] - for i, ( - r, + request_prefilling, next_token_id, all_input_ids, prefix_length, postfix_length, - prompt_length, next_chunk_length, ) in enumerate( zip( - batch.requests, + batch.prefilling_mask, next_token_ids, batch.all_input_ids, batch.prefix_lengths, batch.postfix_lengths, - batch.prompt_lengths, next_chunk_lengths, ) ): - continue_prefilling = prefix_length + postfix_length < prompt_length - if continue_prefilling: - skip_tokens[r.id] = True - # Update prefix length - prefix_length = prefix_length + postfix_length - batch.prefix_lengths[i] = prefix_length - - # Update postfix length - 
postfix_length = next_chunk_length - batch.max_postfix_length = max( - batch.max_postfix_length, postfix_length - ) - batch.postfix_lengths[i] = postfix_length - - # Potentially update max_current_length - current_length = prefix_length + postfix_length - batch.max_current_length = max( - batch.max_current_length, current_length - ) - + if request_prefilling: + next_prefix_length = prefix_length + postfix_length # Get new prompt IDs to prefill postfix_ids = all_input_ids[ - prefix_length : prefix_length + postfix_length + next_prefix_length : next_prefix_length + next_chunk_length ] - - # Position ids - request_position_ids = torch.arange( - prefix_length, prefix_length + postfix_length, dtype=torch.int32 - ) - position_ids.append(request_position_ids) - - # Add cumulative lengths of all previous inputs - cu_seqlen_prefill.append(cumulative_length + postfix_length) - - request_slots = r.slots[prefix_length:] - request_slot_indices = torch.arange( - cumulative_slot_tokens, - cumulative_slot_tokens + postfix_length, - dtype=torch.int64, - ) - - # Create tensor to slice into the kv tensor in prefill - if sliding_window is not None: - request_prefill_cache_indices = torch.arange( - cumulative_length + max(0, postfix_length - sliding_window), - cumulative_length + postfix_length, - dtype=torch.int64, - ) - - all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs - no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs - - if r.prefill_logprobs: - prefill_head_indices.append( - request_position_ids + cumulative_length - ) - prefill_next_token_indices.append( - prefill_out_cumulative_length + postfix_length - 1 - ) - prefill_cu_outlens.append( - prefill_out_cumulative_length + postfix_length - ) - prefill_out_cumulative_length += postfix_length - else: - prefill_head_indices.append( - torch.tensor( - [cumulative_length + postfix_length - 1], - dtype=torch.int32, - ) - ) - prefill_next_token_indices.append(prefill_out_cumulative_length) - prefill_cu_outlens.append(prefill_out_cumulative_length + 1) - prefill_out_cumulative_length += 1 - else: # This request is done prefilling, the new id is the one selected the sampling method postfix_ids = [next_token_id] - # Position_ids - position_ids.append( - torch.tensor( - (prefix_length + postfix_length,), dtype=torch.int32 - ) - ) - - # Add this request token - cu_seqlen_prefill.append(cumulative_length + 1) - - request_slots = r.slots[prefix_length:] - request_slot_indices = torch.tensor( - (cumulative_slot_tokens + postfix_length,), dtype=torch.int64 - ) - - # Create tensor to slice into the kv tensor in prefill - if sliding_window is not None: - request_prefill_cache_indices = torch.tensor( - [cumulative_length], dtype=torch.int64 - ) - - prefill_head_indices.append( - torch.tensor([cumulative_length], dtype=torch.int32) - ) - prefill_next_token_indices.append(prefill_out_cumulative_length) - prefill_cu_outlens.append(prefill_out_cumulative_length + 1) - prefill_out_cumulative_length += 1 - all_postfix_ids.extend(postfix_ids) - start_slots.append(cumulative_slot_tokens) - slots.extend(request_slots) - slot_indices.append(request_slot_indices) - if sliding_window is not None: - prefill_cache_indices.append(request_prefill_cache_indices) - - ADAPTER_TO_INDEX = get_adapter_to_index() - adapter_index = ADAPTER_TO_INDEX.get(r.adapter_id, 0) - adapter_indices_list.append( - torch.full((next_chunk_length,), adapter_index) - ) - - # Update - cumulative_length += next_chunk_length - cumulative_slot_tokens += len(request_slots) - - 
device = batch.input_ids.device - batch.start_slots = torch.tensor(start_slots, dtype=torch.int64) - - if len(batch) > 1: - position_ids = torch.cat(position_ids) - slot_indices = torch.cat(slot_indices) - if sliding_window is not None: - prefill_cache_indices = torch.cat(prefill_cache_indices) - else: - position_ids = position_ids[0] - slot_indices = slot_indices[0] - if sliding_window is not None: - prefill_cache_indices = prefill_cache_indices[0] - - cu_seqlen_prefill = torch.tensor( - cu_seqlen_prefill, device=device, dtype=torch.int32 - ) - batch.cu_seqlen_prefill = cu_seqlen_prefill - batch.position_ids = position_ids.to(device) - batch.slot_indices = slot_indices.to(device) - batch.prefill_cache_indices = ( - prefill_cache_indices.to(device) if sliding_window is not None else None - ) - batch.input_ids = torch.tensor( - all_postfix_ids, dtype=torch.int64, device=device - ) - batch.postfix_lengths_tensor = torch.tensor( - batch.postfix_lengths, dtype=torch.int32, device=device - ) - - if all_prefill_logprobs: - prefill_head_indices = None - prefill_next_token_indices = cu_seqlen_prefill[1:] - 1 - elif no_prefill_logprobs: - prefill_head_indices = cu_seqlen_prefill[1:] - 1 - prefill_next_token_indices = None - else: - prefill_head_indices = torch.tensor( - torch.cat(prefill_head_indices), dtype=torch.int64, device=device - ) - prefill_next_token_indices = torch.tensor( - prefill_next_token_indices, dtype=torch.int64, device=device - ) - - batch.prefill_head_indices = prefill_head_indices - batch.prefill_next_token_indices = prefill_next_token_indices - batch.slots = torch.tensor(slots, dtype=torch.int64, device=device) - batch.prefix_lengths_tensor = torch.tensor( - batch.prefix_lengths, dtype=torch.int32, device=device - ) - adapter_indices = torch.cat(adapter_indices_list).to( - dtype=torch.int64, device=device - ) - adapter_segments, adapter_segment_indices = find_segments(adapter_indices) - adapter_segments = torch.tensor( - adapter_segments, dtype=torch.int32, device=device - ) - batch.adapter_meta = AdapterBatchMetadata( - adapter_indices=adapter_indices, - adapter_set=batch.adapter_meta.adapter_set, - adapter_segments=adapter_segments, - segment_indices=adapter_segment_indices, + batch.input_ids = batch.input_ids.new_tensor( + all_postfix_ids, dtype=torch.int64 ) start_decode = time.time_ns() + # Results + generations: List[Generation] = [] + stopped = True + # Zipped iterator iterator = zip( batch.requests, @@ -2072,11 +1955,14 @@ class FlashCausalLM(Model): batch.next_token_chooser.do_sample, batch.next_token_chooser.seeds, batch.top_n_tokens, + batch.prefilling_mask, accepted_ids, batch_top_token_ids, batch_top_token_logprobs, ) + # Reset max_postfix_length + batch.max_postfix_length = 0 # For each member of the batch index = 0 for i, ( @@ -2092,6 +1978,7 @@ class FlashCausalLM(Model): do_sample, seed, top_n_tokens, + request_prefilling, n_accepted_ids, top_token_ids, top_token_logprobs, @@ -2139,134 +2026,133 @@ class FlashCausalLM(Model): else: batch.prefill_tokens[i] = None - # Represent whether this request is still prefilling # If it is, the tokens we decoded should be ignored - skip_token = skip_tokens.get(request.id, False) - - if skip_token: + if request_prefilling: # Make sure that we do not stop as even though this request did not create a token, it is still # processing stopped = False - # Skip the rest of the decoding - # Values were updated before this for loop - continue + new_postfix_length = next_chunk_lengths[i] + else: + new_postfix_length = n_accepted_ids 
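Once a prefill forward returns, the next batch of input ids is either the next slice of the prompt (for requests that are still prefilling) or the single sampled token (for requests that just finished). A self-contained sketch of that selection, with a hypothetical helper name and invented token values:

```python
from typing import List

# Hypothetical helper mirroring the branch in the loop above.
def next_postfix_ids(
    request_prefilling: bool,
    all_input_ids: List[int],
    prefix_length: int,
    postfix_length: int,
    next_chunk_length: int,
    next_token_id: int,
) -> List[int]:
    if request_prefilling:
        # Still consuming the prompt: take the next chunk of prompt tokens.
        next_prefix_length = prefix_length + postfix_length
        return all_input_ids[
            next_prefix_length : next_prefix_length + next_chunk_length
        ]
    # Done prefilling: the single id produced by the sampling method.
    return [next_token_id]

# A 6-token prompt processed in chunks of 4 then 2:
prompt = [10, 11, 12, 13, 14, 15]
assert next_postfix_ids(True, prompt, 0, 4, 2, next_token_id=99) == [14, 15]
assert next_postfix_ids(False, prompt, 4, 2, 1, next_token_id=99) == [99]
```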
+ # Append next token to all tokens + next_token_texts = [] + left = 0 - # Append next token to all tokens - next_token_texts = [] - left = 0 + if n_accepted_ids > 1: + log_master(logger.debug, f"speculated ids {n_accepted_ids - 1}") - if n_accepted_ids > 1: - log_master(logger.debug, f"speculated ids {n_accepted_ids - 1}") - - current_stopped = False - for j in range(index, index + n_accepted_ids): - # Generated token - next_token_id = next_token_ids[j] - all_input_ids.append(next_token_id) - next_token_text, prefix_offset, read_offset = self.decode_token( - all_input_ids, - prefix_offset, - read_offset, - ) - next_token_texts.append(next_token_text) - - stop, reason = stopping_criteria( - next_token_id, - next_token_text, - ) - - if stop: - left = index + n_accepted_ids - j - 1 - current_stopped = True - break - else: - current_stopped = False - stopped = stopped and current_stopped - - _next_token_ids = next_token_ids[index : index + n_accepted_ids - left] - _next_token_logprobs = next_token_logprobs[ - index : index + n_accepted_ids - left - ] - index += n_accepted_ids - - # Shard generations - # All generations will be appended in the rust sharded client - if request.id % self.world_size == self.rank: - if stop: - # Decode generated tokens - output_text, _, _ = self.decode_token( + current_stopped = False + for j in range(index, index + n_accepted_ids): + # Generated token + next_token_id = next_token_ids[j] + all_input_ids.append(next_token_id) + next_token_text, prefix_offset, read_offset = self.decode_token( all_input_ids, - prefix_offset=len(all_input_ids) - - stopping_criteria.current_tokens - - 1, - read_offset=len(all_input_ids) - - stopping_criteria.current_tokens, - skip_special_tokens=True, + prefix_offset, + read_offset, ) - generated_text = GeneratedText( - output_text, - stopping_criteria.current_tokens, - reason, - seed if do_sample else None, + next_token_texts.append(next_token_text) + + stop, reason = stopping_criteria( + next_token_id, + next_token_text, ) - else: - generated_text = None - if top_n_tokens > 0: - all_top_tokens = [] - for top_token_ids, top_token_logprobs in zip( - top_token_ids, top_token_logprobs - ): - toptoken_texts = self.tokenizer.batch_decode( - top_token_ids, - clean_up_tokenization_spaces=False, - skip_special_tokens=False, + if stop: + left = index + n_accepted_ids - j - 1 + current_stopped = True + break + else: + current_stopped = False + stopped = stopped and current_stopped + + _next_token_ids = next_token_ids[index : index + n_accepted_ids - left] + _next_token_logprobs = next_token_logprobs[ + index : index + n_accepted_ids - left + ] + index += n_accepted_ids + + # Shard generations + # All generations will be appended in the rust sharded client + if request.id % self.world_size == self.rank: + if stop: + # Decode generated tokens + output_text, _, _ = self.decode_token( + all_input_ids, + prefix_offset=len(all_input_ids) + - stopping_criteria.current_tokens + - 1, + read_offset=len(all_input_ids) + - stopping_criteria.current_tokens, + skip_special_tokens=True, ) - special_toptokens = [ - token_id in self.all_special_ids - for token_id in top_token_ids - ] - top_tokens = Tokens( - top_token_ids, - top_token_logprobs, - toptoken_texts, - special_toptokens, + generated_text = GeneratedText( + output_text, + stopping_criteria.current_tokens, + reason, + seed if do_sample else None, ) - all_top_tokens.append(top_tokens) - top_tokens = all_top_tokens - else: - top_tokens = None + else: + generated_text = None - generation = Generation( - 
request.id, - batch.prefill_tokens[i], - Tokens( - _next_token_ids, - _next_token_logprobs, - next_token_texts, - [nid in self.all_special_ids for nid in _next_token_ids], - ), - generated_text, - top_tokens, - ) + if top_n_tokens > 0: + all_top_tokens = [] + for top_token_ids, top_token_logprobs in zip( + top_token_ids, top_token_logprobs + ): + toptoken_texts = self.tokenizer.batch_decode( + top_token_ids, + clean_up_tokenization_spaces=False, + skip_special_tokens=False, + ) + special_toptokens = [ + token_id in self.all_special_ids + for token_id in top_token_ids + ] + top_tokens = Tokens( + top_token_ids, + top_token_logprobs, + toptoken_texts, + special_toptokens, + ) + all_top_tokens.append(top_tokens) + top_tokens = all_top_tokens + else: + top_tokens = None - generations.append(generation) + generation = Generation( + request.id, + batch.prefill_tokens[i], + Tokens( + _next_token_ids, + _next_token_logprobs, + next_token_texts, + [nid in self.all_special_ids for nid in _next_token_ids], + ), + generated_text, + top_tokens, + ) - # accept each new token for this specific request since we may - # have more than one new token per request with speculative decoding - for next_token_id in _next_token_ids: - batch.next_token_chooser = ( - batch.next_token_chooser.advance_grammar_single(i, next_token_id) - ) + generations.append(generation) + + # accept each new token for this specific request since we may + # have more than one new token per request with speculative decoding + for next_token_id in _next_token_ids: + batch.next_token_chooser = ( + batch.next_token_chooser.advance_grammar_single(i, next_token_id) + ) # Update values - current_postfix_length = postfix_length + n_accepted_ids + current_prefix_length = prefix_length + postfix_length + batch.prefix_lengths[i] = current_prefix_length + current_postfix_length = new_postfix_length batch.max_postfix_length = max( batch.max_postfix_length, current_postfix_length ) batch.postfix_lengths[i] = current_postfix_length - current_length = prefix_length + current_postfix_length + current_length = current_prefix_length + current_postfix_length batch.max_current_length = max(batch.max_current_length, current_length) + batch.prefix_offsets[i] = prefix_offset batch.read_offsets[i] = read_offset batch.all_input_ids[i] = all_input_ids diff --git a/server/text_generation_server/utils/segments.py b/server/text_generation_server/utils/segments.py index f5961102..b3f92369 100644 --- a/server/text_generation_server/utils/segments.py +++ b/server/text_generation_server/utils/segments.py @@ -7,6 +7,7 @@ from typing import List, Tuple, Union import torch +# FIXME: this should be optimized def find_segments( adapter_indices: Union[torch.Tensor, List[int]] ) -> Tuple[List[int], List[int]]:
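The final hunk only adds a FIXME above `find_segments`, which `prepare_for_prefill` now calls on every chunk to rebuild the adapter metadata. For context, a pure-Python sketch of the run-length grouping the function appears to perform; the shipped implementation may differ in details such as tensor handling, and the example values are invented:

```python
from typing import List, Tuple, Union

import torch

# Sketch of the grouping the FIXME refers to: consecutive equal adapter indices
# are collapsed into segments, returning the segment boundaries and the adapter
# index of each segment.
def find_segments_sketch(
    adapter_indices: Union[torch.Tensor, List[int]]
) -> Tuple[List[int], List[int]]:
    if isinstance(adapter_indices, torch.Tensor):
        adapter_indices = adapter_indices.tolist()
    segments = [0]
    segment_indices = []
    for i, idx in enumerate(adapter_indices):
        if i > 0 and idx != adapter_indices[i - 1]:
            segments.append(i)
        if i == 0 or idx != adapter_indices[i - 1]:
            segment_indices.append(idx)
    segments.append(len(adapter_indices))
    return segments, segment_indices

# A flat per-token index list with two runs of adapters (0 then 2) yields two segments:
assert find_segments_sketch([0, 0, 0, 2, 2]) == ([0, 3, 5], [0, 2])
```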