diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 2836dcc8..8ee9d184 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -16,7 +16,7 @@ from transformers import (
     AutoTokenizer,
     GenerationConfig,
 )
-from typing import Any, ContextManager, Iterable, Optional, Tuple, List, Type, Dict
+from typing import Any, ContextManager, Iterable, Optional, Tuple, List, Type, Dict, Union

 from text_generation_server.adapters import AdapterBatchData, AdapterBatchMetadata
 from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
@@ -119,7 +119,9 @@ class FlashCausalLMBatch(Batch):
     requests_idx_mapping: Dict[int, int]

     # Decoder values
-    input_ids: torch.Tensor
+    # Can be a list for easy filtering
+    # If `input_ids` is a list, it needs to be materialized to a tensor first
+    input_ids: Union[torch.Tensor, List[List[int]]]
     # Will be set by `generate_token` and reset after each prefill forward before staying set in decode
     position_ids: Optional[torch.Tensor]
     speculative_ids: Optional[torch.Tensor]
@@ -178,7 +180,7 @@
     # Will be set by `generate_token` and reset after each prefill forward before staying set in decode
     postfix_lengths_tensor: Optional[torch.Tensor]
     prefix_lengths_tensor: Optional[torch.Tensor]
-    prompt_lengths_tensor: Optional[torch.Tensor]
+    prompt_lengths_tensor: torch.Tensor

     prefix_offsets: List[Optional[int]]
     read_offsets: List[Optional[int]]
@@ -350,12 +352,6 @@ class FlashCausalLMBatch(Batch):
             all_input_ids_tensor, dtype=torch.int64, device=device
         )

-        if len(pb.requests) > 1:
-            input_ids = np.concatenate(all_postfix_ids, dtype=np.int64)
-        else:
-            input_ids = all_postfix_ids[0]
-        input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
-
         top_n_tokens_tensor = torch.tensor(
             top_n_tokens, device=device, dtype=torch.int64
         )
@@ -366,12 +362,15 @@
         for i, request_blocks in enumerate(block_tables):
             block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks)
         block_tables_tensor = block_tables_tensor.to(device)
+        prompt_lengths_tensor = torch.tensor(
+            prompt_lengths, dtype=torch.int32, device=device
+        )

         return cls(
             batch_id=pb.id,
             requests=pb.requests,
             requests_idx_mapping=requests_idx_mapping,
-            input_ids=input_ids,
+            input_ids=all_postfix_ids,
             block_tables=block_tables,
             block_tables_tensor=block_tables_tensor,
@@ -395,6 +394,7 @@
             num_blocks=num_blocks,
             max_blocks=max_blocks,
             speculative_ids=None,
+            prompt_lengths_tensor=prompt_lengths_tensor,

             # These values will be set by `FlashCausalLMBatch.prepare_for_prefill`
             position_ids=None,
@@ -408,7 +408,6 @@
             prefill_cu_outlens=None,
             prefix_lengths_tensor=None,
             postfix_lengths_tensor=None,
-            prompt_lengths_tensor=None,
             adapter_meta=None,
         )
@@ -455,6 +454,7 @@
         block_tables = []
         all_input_ids = []
         prefix_ids = []
+        input_ids = []

         prompt_lengths = []
         postfix_lengths = []
@@ -473,7 +473,6 @@
         max_blocks = 0
         # Cumulative length
         cumulative_max_length = 0
-        prefilling=False

         for i, request_id in enumerate(request_ids):
             idx = self.requests_idx_mapping[request_id]
@@ -484,9 +483,13 @@

             # Prefilling
             request_prefilling = self.prefilling_mask[idx]
-            prefilling = prefilling or request_prefilling
             prefilling_mask.append(request_prefilling)

+            # Input ids if the request was part of a prefilling batch
+            # If the batch was decoding we can index into the tensor directly later
+            if self.prefilling:
+                input_ids.append(self.input_ids[idx])
+
             # Get length
             request_postfix_length = self.postfix_lengths[idx]
             request_prefix_length = self.prefix_lengths[idx]
@@ -538,32 +541,48 @@
             max_blocks = max(max_blocks, len(request_block_table))

-        # Index into tensors
-        input_ids = self.input_ids[indices]
-        position_ids = self.position_ids[indices]
-        adapter_indices = self.adapter_meta.adapter_indices[indices]
         all_input_ids_tensor = self.all_input_ids_tensor[indices]
         block_tables_tensor = self.block_tables_tensor[indices]
-        postfix_lengths_tensor = self.postfix_lengths_tensor[indices]
-        prompt_lengths_tensor = self.prompt_lengths_tensor[indices]
-        slots = self.slots[slot_filtering_indices]
-        prefix_lengths_tensor = self.prefix_lengths_tensor[indices]
         next_token_chooser = self.next_token_chooser.filter(indices)
         top_n_tokens_tensor = self.top_n_tokens_tensor[indices]
         speculative_ids = (
             self.speculative_ids[indices] if self.speculative_ids is not None else None
         )
+        prompt_lengths_tensor = self.prompt_lengths_tensor[indices]

-        start_slots = torch.tensor(start_slots, dtype=torch.int64)
+        if self.prefilling:
+            # These values will be set by `FlashCausalLMBatch.prepare_for_prefill`
+            position_ids=None
+            start_slots=None
+            slot_indices=None
+            slots=None
+            prefix_lengths_tensor=None
+            postfix_lengths_tensor=None
+            adapter_meta=None
+        else:
+            # Index into tensors
+            input_ids = self.input_ids[indices]
+            position_ids = self.position_ids[indices]
+            adapter_indices = self.adapter_meta.adapter_indices[indices]
+            postfix_lengths_tensor = self.postfix_lengths_tensor[indices]
+            slots = self.slots[slot_filtering_indices]
+            prefix_lengths_tensor = self.prefix_lengths_tensor[indices]

-        # Move to GPU now that we have the whole tensor
-        slot_indices = slot_indices.to(device)
+            start_slots = torch.tensor(start_slots, dtype=torch.int64)

-        adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
-        adapter_segments = torch.tensor(
-            adapter_segments, dtype=torch.int32, device=device
-        )
-        # assert sum(len(b) for b in block_tables) == (block_tables_tensor != 0).sum()
+            # Move to GPU now that we have the whole tensor
+            slot_indices = slot_indices.to(device)
+
+            adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
+            adapter_segments = torch.tensor(
+                adapter_segments, dtype=torch.int32, device=device
+            )
+            adapter_meta = AdapterBatchMetadata(
+                adapter_indices=adapter_indices,
+                adapter_set=adapter_set,
+                adapter_segments=adapter_segments,
+                segment_indices=adapter_segment_indices,
+            )

         return type(self)(
             batch_id=self.batch_id,
@@ -580,7 +599,7 @@
             slots=slots,
             max_postfix_length=max_postfix_length,
             max_current_length=max_current_length,
-            prefilling=prefilling,
+            prefilling=self.prefilling,
             prefilling_mask=prefilling_mask,
             prefill_head_indices=None,
             prefill_next_token_indices=None,
@@ -604,12 +623,7 @@
             num_blocks=num_blocks,
             max_blocks=max_blocks,
             speculative_ids=speculative_ids,
-            adapter_meta=AdapterBatchMetadata(
-                adapter_indices=adapter_indices,
-                adapter_set=adapter_set,
-                adapter_segments=adapter_segments,
-                segment_indices=adapter_segment_indices,
-            ),
+            adapter_meta=adapter_meta,
         )

     @classmethod
@@ -652,38 +666,51 @@
             )
             prefilling = prefilling or b.prefilling

-        input_ids = batches[0].input_ids.new_empty(total_batch_size)
-        position_ids = batches[0].position_ids.new_empty(total_batch_size)
-        slots = batches[0].slots.new_empty(total_slots)
-        slot_indices = batches[0].slot_indices.new_empty(total_batch_size)
+        if prefilling:
+            input_ids = []
+            # These values will be set by `FlashCausalLMBatch.prepare_for_prefill`
+            position_ids=None
+            start_slots=None
+            slots=None
+            slot_indices=None
+            prefix_lengths_tensor=None
+            postfix_lengths_tensor=None
+            adapter_meta=None
+            adapter_segment_builder=None
+        else:
+            input_ids = batches[0].input_ids.new_empty(total_batch_size)
+            position_ids = batches[0].position_ids.new_empty(total_batch_size)
+            start_slots = []
+            slots = batches[0].slots.new_empty(total_slots)
+            slot_indices = batches[0].slot_indices.new_empty(total_batch_size)
+            postfix_lengths_tensor = batches[0].postfix_lengths_tensor.new_empty(
+                total_batch_size
+            )
+            prefix_lengths_tensor = batches[0].prefix_lengths_tensor.new_empty(
+                total_batch_size
+            )
+            total_indices_size = sum(
+                b.adapter_meta.adapter_indices.shape[0] for b in batches
+            )
+            adapter_indices = batches[0].adapter_meta.adapter_indices.new_empty(
+                total_indices_size
+            )
+            adapter_segment_builder = SegmentConcatBuilder()
+            adapter_set = set()
+
         prompt_lengths_tensor = batches[0].prompt_lengths_tensor.new_empty(
             total_batch_size
         )
-        postfix_lengths_tensor = batches[0].postfix_lengths_tensor.new_empty(
-            total_batch_size
-        )
         block_tables_tensor = batches[0].block_tables_tensor.new_zeros(
             (total_batch_size, max_blocks)
         )
-        prefix_lengths_tensor = batches[0].prefix_lengths_tensor.new_empty(
-            total_batch_size
-        )
         all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros(
             (total_batch_size, max_length)
         )
         top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(
             total_batch_size,
         )
-        total_indices_size = sum(
-            b.adapter_meta.adapter_indices.shape[0] for b in batches
-        )
-        adapter_indices = batches[0].adapter_meta.adapter_indices.new_empty(
-            total_indices_size
-        )
-        adapter_set = set()
-        adapter_segment_builder = SegmentConcatBuilder()
-        start_slots = []

         block_tables = []
         prefix_lengths = []
         all_input_ids = []
@@ -723,29 +750,7 @@
             slots_end_index = cumulative_slots + len(batch.slots)

             # Copy tensors (GPU)
-            input_ids[start_index:end_index] = batch.input_ids
-            position_ids[start_index:end_index] = batch.position_ids
-            slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots
-            prompt_lengths_tensor[start_index:end_index] = batch.prompt_lengths_tensor
-            postfix_lengths_tensor[start_index:end_index] = batch.postfix_lengths_tensor
             top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
-            slots[slots_start_index:slots_end_index] = batch.slots
-
-            # Copy over adapter indices
-            adapter_start_index = cumulative_adapter_indices_size
-            adapter_end_index = (
-                cumulative_adapter_indices_size
-                + batch.adapter_meta.adapter_indices.shape[0]
-            )
-            adapter_indices[adapter_start_index:adapter_end_index] = (
-                batch.adapter_meta.adapter_indices
-            )
-            cumulative_adapter_indices_size = adapter_end_index
-            adapter_set.update(batch.adapter_meta.adapter_set)
-            adapter_segment_builder.concat(
-                batch.adapter_meta.adapter_segments, batch.adapter_meta.segment_indices
-            )
-
             all_input_ids_tensor[
                 start_index:end_index, : batch.all_input_ids_tensor.shape[1]
             ] = batch.all_input_ids_tensor[:, :max_length]
@@ -753,12 +758,38 @@
             block_tables_tensor[
                 start_index:end_index, : batch.block_tables_tensor.shape[1]
             ] = batch.block_tables_tensor[:, :max_blocks]
+            prompt_lengths_tensor[start_index:end_index] = batch.prompt_lengths_tensor

-            prefix_lengths_tensor[start_index:end_index] = batch.prefix_lengths_tensor
+            if not prefilling:
+                input_ids[start_index:end_index] = batch.input_ids
+                position_ids[start_index:end_index] = batch.position_ids
+                slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots
+                postfix_lengths_tensor[start_index:end_index] = batch.postfix_lengths_tensor
+                slots[slots_start_index:slots_end_index] = batch.slots

-            start_slots.append(batch.start_slots + cumulative_slots)
+                # Copy over adapter indices
+                adapter_start_index = cumulative_adapter_indices_size
+                adapter_end_index = (
+                    cumulative_adapter_indices_size
+                    + batch.adapter_meta.adapter_indices.shape[0]
+                )
+                adapter_indices[adapter_start_index:adapter_end_index] = (
+                    batch.adapter_meta.adapter_indices
+                )
+                cumulative_adapter_indices_size = adapter_end_index
+                adapter_set.update(batch.adapter_meta.adapter_set)
+                adapter_segment_builder.concat(
+                    batch.adapter_meta.adapter_segments, batch.adapter_meta.segment_indices
+                )
+                prefix_lengths_tensor[start_index:end_index] = batch.prefix_lengths_tensor

-            prefilling_mask = prefilling_mask.extend(batch.prefilling_mask)
+                start_slots.append(batch.start_slots + cumulative_slots)
+            else:
+                if isinstance(batch.input_ids, torch.Tensor):
+                    batch.input_ids = batch.input_ids.view(-1, 1).tolist()
+                input_ids.extend(batch.input_ids)
+
+            prefilling_mask.extend(batch.prefilling_mask)
             block_tables.extend(batch.block_tables)
             prefix_lengths.extend(batch.prefix_lengths)
             all_input_ids.extend(batch.all_input_ids)
@@ -781,7 +812,8 @@
             cumulative_batch_size += len(batch)
             cumulative_slots += len(batch.slots)

-        start_slots = torch.concat(start_slots)
+        if start_slots is not None:
+            start_slots = torch.concat(start_slots)

         # assert sum(len(b) for b in block_tables) == (block_tables_tensor != 0).sum()
@@ -799,7 +831,14 @@
             else None
         )

-        adapter_segments, adapter_segment_indices = adapter_segment_builder.build()
+        if adapter_segment_builder is not None:
+            adapter_segments, adapter_segment_indices = adapter_segment_builder.build()
+            adapter_meta = AdapterBatchMetadata(
+                adapter_indices=adapter_indices,
+                adapter_set=adapter_set,
+                adapter_segments=adapter_segments,
+                segment_indices=adapter_segment_indices,
+            )

         return cls(
             batch_id=batches[0].batch_id,
@@ -840,12 +879,7 @@
             num_blocks=num_blocks,
             max_blocks=max_blocks,
             speculative_ids=speculative_ids,
-            adapter_meta=AdapterBatchMetadata(
-                adapter_indices=adapter_indices,
-                adapter_set=adapter_set,
-                adapter_segments=adapter_segments,
-                segment_indices=adapter_segment_indices,
-            ),
+            adapter_meta=adapter_meta,
         )

     def prepare_for_prefill(self):
@@ -973,9 +1007,16 @@
             cumulative_length += next_chunk_length
             cumulative_slot_tokens += len(request_slots)

-        device = self.input_ids.device
+        device = self.block_tables_tensor.device
         self.start_slots = torch.tensor(start_slots, dtype=torch.int64)

+        if isinstance(self.input_ids, list):
+            if len(self) > 1:
+                input_ids = np.concatenate(self.input_ids, dtype=np.int64)
+            else:
+                input_ids = self.input_ids[0]
+            self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
+
         if len(self) > 1:
             position_ids = torch.cat(position_ids)
             slot_indices = torch.cat(slot_indices)
@@ -1865,7 +1906,8 @@ class FlashCausalLM(Model):
         batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
         batch.speculative_ids = speculative_ids
         batch.position_ids = next_position_ids + accepted_ids
-        batch.postfix_lengths_tensor += accepted_ids
+        batch.prefix_lengths_tensor += batch.postfix_lengths_tensor
+        batch.postfix_lengths_tensor = accepted_ids
         batch.slot_indices += accepted_ids
         batch.adapter_meta.adapter_indices = next_adapter_indices
@@ -1929,11 +1971,9 @@ class FlashCausalLM(Model):
                 # This request is done prefilling, the new id is the one selected the sampling method
                 postfix_ids = [next_token_id]

-            all_postfix_ids.extend(postfix_ids)
+            all_postfix_ids.append(postfix_ids)

-        batch.input_ids = batch.input_ids.new_tensor(
-            all_postfix_ids, dtype=torch.int64
-        )
+        batch.input_ids = all_postfix_ids

         start_decode = time.time_ns()
@@ -2014,7 +2054,7 @@
                     )

                     prefill_tokens = Tokens(
-                        prefix_ids + prefill_token_ids,
+                        prefill_token_ids,
                         request_prefill_logprobs,
                         prefill_texts,
                         is_special=[],
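The central pattern in this patch is deferring materialization of `input_ids`: while a batch is still prefilling, per-request token ids stay in a plain Python list (cheap to filter and concatenate), and only `prepare_for_prefill` flattens them into a device tensor. A minimal standalone sketch of that pattern follows; the helper name `materialize_input_ids` is hypothetical and not part of the patch, but the flattening logic mirrors the `prepare_for_prefill` hunk above.

import numpy as np
import torch
from typing import List, Union

def materialize_input_ids(
    input_ids: Union[torch.Tensor, List[List[int]]], device: torch.device
) -> torch.Tensor:
    # Decode path: the batch already holds a flat tensor, nothing to do.
    if isinstance(input_ids, torch.Tensor):
        return input_ids
    # Prefill path, mirroring `prepare_for_prefill`: flatten the per-request
    # lists on the CPU first, then issue a single host-to-device copy.
    if len(input_ids) > 1:
        flat = np.concatenate(input_ids, dtype=np.int64)
    else:
        flat = input_ids[0]
    return torch.tensor(flat, dtype=torch.int64, device=device)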
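`filter` benefits from the same split: a prefilling batch keeps a subset of its list with plain Python indexing (the `input_ids.append(self.input_ids[idx])` loop above), while a decoding batch gathers rows from its tensor in one indexing operation. A simplified sketch, with a hypothetical function and arguments that stand in for the batch state:

import torch
from typing import List, Union

def filter_input_ids(
    input_ids: Union[torch.Tensor, List[List[int]]],
    keep: List[int],
    prefilling: bool,
):
    if prefilling:
        # List path: no device sync; the tensor is materialized later.
        return [input_ids[i] for i in keep]
    # Decode path: a single gather; torch accepts a Python list as an index.
    return input_ids[keep]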
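When `concatenate` merges a decoding batch into a prefilling one, the decoding batch's 1-D `input_ids` tensor (one next-token id per request) has to be converted back into list-of-lists form, which is what the `view(-1, 1).tolist()` line above does. A small illustration with made-up token ids:

import torch

decode_input_ids = torch.tensor([17, 42, 99])     # one next-token id per request
as_lists = decode_input_ids.view(-1, 1).tolist()  # [[17], [42], [99]]
prefill_input_ids = [[5, 6, 7], [8, 9]]           # chunked prompt ids per request

merged = as_lists + prefill_input_ids
assert merged == [[17], [42], [99], [5, 6, 7], [8, 9]]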
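The replacement of `postfix_lengths_tensor += accepted_ids` with the two lines above encodes the KV-cache accounting after a forward pass: the tokens just fed to the model are now cached, so their count folds into the prefix, and only the freshly accepted ids remain as the next postfix. A hedged numeric sketch (the values are invented; the real tensors live on the GPU):

import torch

prefix_lengths = torch.tensor([0, 100])   # tokens already in the KV cache
postfix_lengths = torch.tensor([50, 20])  # tokens processed this forward pass
accepted_ids = torch.tensor([1, 3])       # 1 without speculation, >1 with it

prefix_lengths += postfix_lengths  # the cache now covers prefix + postfix
postfix_lengths = accepted_ids     # only the fresh tokens run next step

assert prefix_lengths.tolist() == [50, 120]
assert postfix_lengths.tolist() == [1, 3]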