From 18e77a5cc7aeab1784c5c8d6f4cb2b6b8d044078 Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Wed, 5 Jun 2024 15:28:10 +0200
Subject: [PATCH] wip

---
 proto/v3/generate.proto                        |  24 ++++
 router/src/infer/v3/queue.rs                   |   3 +
 router/src/infer/v3/scheduler.rs               |  35 +++++-
 router/src/lib.rs                              |   3 +
 .../models/causal_lm.py                        |   1 +
 .../models/flash_causal_lm.py                  | 103 +++++++-----------
 .../models/idefics_causal_lm.py                |   1 +
 server/text_generation_server/models/mamba.py  |   1 +
 .../models/seq2seq_lm.py                       |   1 +
 server/text_generation_server/models/types.py  |   2 +
 .../models/vlm_causal_lm.py                    |   4 +-
 11 files changed, 112 insertions(+), 66 deletions(-)

diff --git a/proto/v3/generate.proto b/proto/v3/generate.proto
index 01cc43fd..d57fbbad 100644
--- a/proto/v3/generate.proto
+++ b/proto/v3/generate.proto
@@ -17,6 +17,8 @@ service TextGenerationService {
     rpc Prefill (PrefillRequest) returns (PrefillResponse);
     /// Decode token for a list of prefilled batches
     rpc Decode (DecodeRequest) returns (DecodeResponse);
+    /// Update batch
+    rpc Update(UpdateRequest) returns (UpdateResponse);
     /// Health check
     rpc Health (HealthRequest) returns (HealthResponse);
 }
@@ -198,6 +200,8 @@ message Generation {
     optional GeneratedText generated_text = 4;
     /// Top tokens
     repeated Tokens top_tokens = 5;
+    /// Current length of the request: prompt tokens + number of generated tokens until this point
+    uint32 current_length = 6;
 }
 
 message FilterBatchRequest {
@@ -251,6 +255,26 @@ message DecodeResponse {
     optional uint64 concat_ns = 6;
 }
 
+message ExtendedRequest {
+    /// Request ID
+    uint64 request_id = 1;
+    /// Paged attention blocks to add
+    repeated uint32 blocks = 2;
+    /// Paged attention slots to add
+    repeated uint32 slots = 3;
+}
+
+message UpdateRequest {
+    /// Batch ID
+    uint64 batch_id = 1;
+    /// Requests to update
+    repeated ExtendedRequest extend_requests = 2;
+    /// Requests to terminate
+    repeated uint64 terminated_request_ids = 3;
+}
+
+message UpdateResponse {}
+
 message WarmupRequest {
     /// Batch to warmup on
     Batch batch = 1;
diff --git a/router/src/infer/v3/queue.rs b/router/src/infer/v3/queue.rs
index 0b66142a..15226794 100644
--- a/router/src/infer/v3/queue.rs
+++ b/router/src/infer/v3/queue.rs
@@ -33,6 +33,8 @@ pub(crate) struct Entry {
     pub batch_time: Option<Instant>,
     /// Block Allocation
    pub block_allocation: Option<BlockAllocation>,
+    /// Current length (in tokens) of the request (prompt tokens + generated_tokens)
+    pub current_length: u32,
 }
 
 /// Request Queue
@@ -498,6 +500,7 @@ mod tests {
             queue_time: Instant::now(),
             batch_time: None,
             block_allocation: None,
+            current_length: 0,
         };
         (entry, receiver_tx)
     }
diff --git a/router/src/infer/v3/scheduler.rs b/router/src/infer/v3/scheduler.rs
index ad03dd83..bf52e69f 100644
--- a/router/src/infer/v3/scheduler.rs
+++ b/router/src/infer/v3/scheduler.rs
@@ -88,6 +88,7 @@ impl Scheduler for SchedulerV3 {
             queue_time: Instant::now(),
             batch_time: None,
             block_allocation: None,
+            current_length: input_length,
         });
 
         // Notify the background task that we have a new entry in the queue that needs
@@ -287,6 +288,8 @@ async fn decode(
     // Send generated tokens and filter stopped entries
     filter_send_generations(generations, entries);
 
+    filter_update_allocations(client, entries).await;
+
     // Filter next batch and remove requests that were stopped
     let next_batch = filter_batch(client, next_batch, entries).await;
 
@@ -355,8 +358,9 @@ fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {
+async fn filter_update_allocations(client: &mut ShardedClient, entries: &mut IntMap<u64, Entry>) {
+    // let mut extend_entries = Vec::with_capacity(entries.len());
+    // let mut finish_entries = Vec::with_capacity(entries.len());
+
+    // for (request_id, entry) in entries.into_iter() {
+    //     tracing::info!("Allocation {}; Current Length: {}", entry.block_allocation.as_ref().unwrap().allocated_tokens, entry.current_length);
+    //
+    //     if let Some(block_allocation) = &mut entry.block_allocation {
+    //         tracing::info!("Allocation {:?}", block_allocation);
+    //
+    //         if entry.current_length > block_allocation.allocated_tokens {
+    //             // We need to add new blocks to this entry
+    //             let remaining_tokens = block_allocation.total_tokens - entry.current_length;
+    //             match block_allocation.extend(remaining_tokens).await {
+    //                 true => {
+    //
+    //                 },
+    //                 false => {
+    //
+    //                 }
+    //             }
+    //         }
+    //     }
+    // }
+}
+
 /// Send responses through the `entry` response channel
 fn send_responses(
     generation: Generation,
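Note: the new `Update` RPC carries `ExtendedRequest { request_id, blocks, slots }` messages so the router can grow a running request's paged-attention allocation, and `Generation.current_length` is the number the commented-out `filter_update_allocations` compares against `allocated_tokens`. The sketch below is illustrative only: plain Python, hypothetical names (`BLOCK_SIZE`, `plan_extension`), and the assumption that each block maps to `BLOCK_SIZE` contiguous slots. It is not code from this patch.

```python
import math
from typing import List, Tuple

BLOCK_SIZE = 16  # assumed paged-attention block size


def plan_extension(
    allocated_blocks: List[int],
    current_length: int,
    free_blocks: List[int],
) -> Tuple[List[int], List[int]]:
    """Return the (blocks, slots) to add so the request can hold current_length tokens.

    Mirrors the intent of ExtendedRequest: if the allocation already covers
    current_length, nothing needs to be added.
    """
    allocated_tokens = len(allocated_blocks) * BLOCK_SIZE
    if current_length <= allocated_tokens:
        return [], []

    needed_blocks = math.ceil((current_length - allocated_tokens) / BLOCK_SIZE)
    if needed_blocks > len(free_blocks):
        # Not enough pages left: such a request would be terminated and the
        # client would see finish_reason == "out_of_pages".
        raise RuntimeError("out of pages")

    new_blocks = free_blocks[:needed_blocks]
    new_slots = [s for b in new_blocks for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)]
    return new_blocks, new_slots


# Example: 2 blocks of 16 slots cover 32 tokens; token 33 needs one more block.
blocks, slots = plan_extension(allocated_blocks=[4, 7], current_length=33, free_blocks=[9, 12])
assert blocks == [9] and slots == list(range(144, 160))
```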
diff --git a/router/src/lib.rs b/router/src/lib.rs
index b6902c49..52c5aa46 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -1085,6 +1085,8 @@ pub(crate) enum FinishReason {
     EndOfSequenceToken,
     #[schema(rename = "stop_sequence")]
     StopSequence,
+    #[schema(rename = "out_of_pages")]
+    OutOfPages,
 }
 
 impl std::fmt::Display for FinishReason {
@@ -1093,6 +1095,7 @@ impl std::fmt::Display for FinishReason {
             FinishReason::Length => write!(f, "length"),
             FinishReason::EndOfSequenceToken => write!(f, "eos_token"),
             FinishReason::StopSequence => write!(f, "stop_sequence"),
+            FinishReason::OutOfPages => write!(f, "out_of_pages"),
         }
     }
 }
diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index e896c831..2fe0f56e 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -746,6 +746,7 @@ class CausalLM(Model):
                 ),
                 generated_text,
                 top_tokens,
+                new_input_length,
             )
 
             generations.append(generation)
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index d16d3710..da5fa9db 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -79,8 +79,6 @@ class FlashCausalLMBatch(Batch):
     # Paged Attention values
 
     # Set when creating the batch
-    # CPU tensor of length b indicating the start of each sequence in slots
-    start_slots: torch.Tensor
     # tensor of indices of the currently used slots, length = \sum_{i=0}^{b} s_i in prefill, length = b in decode
     slot_indices: torch.Tensor
 
@@ -88,8 +86,10 @@ class FlashCausalLMBatch(Batch):
     block_tables: List[List[int]]
     # tensor of size [b, max_total_seqlen // block_size] holding the paged attention block tables for all sequences
     block_tables_tensor: torch.Tensor
+    # list of length b of list of length s_i
+    slots: List[List[int]]
     # tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences
-    slots: torch.Tensor
+    slots_tensor: torch.Tensor
 
     max_seqlen: int
 
@@ -154,7 +154,6 @@ class FlashCausalLMBatch(Batch):
         sliding_window = get_sliding_windows()
         position_ids = []
         cu_seqlen_prefill = [0]
-        start_slots = []
         slot_indices = []
 
         prefill_cache_indices = []
@@ -176,7 +175,6 @@ class FlashCausalLMBatch(Batch):
 
         # Cumulative length
         cumulative_length = 0
-        cumulative_max_length = 0
         prefill_out_cumulative_length = 0
 
         num_blocks = 0
@@ -186,6 +184,7 @@ class FlashCausalLMBatch(Batch):
 
         block_tables = []
         slots = []
+        flat_slots = []
 
         # Parse batch
         for i, (r, tokenized_input) in enumerate(
@@ -204,6 +203,9 @@ class FlashCausalLMBatch(Batch):
             input_length = len(tokenized_input)
             input_lengths.append(input_length)
 
+            speculative_length = get_speculate()
+            speculative_length = 0 if speculative_length is None else speculative_length
+
             prefix_offsets.append(input_length - 5)
             read_offsets.append(input_length)
 
@@ -226,13 +228,10 @@ class FlashCausalLMBatch(Batch):
             top_n_tokens.append(r.top_n_tokens)
 
             # Paged attention
-            # Remove one as the first token des not have a past
-            speculative_length = get_speculate()
-            speculative_length = 0 if speculative_length is None else speculative_length
-            total_tokens = input_length + max_new_tokens - 1 + speculative_length
-
             # blocks and slots can be empty (for example in warmup)
             if not r.blocks:
+                # Remove one as the first token does not have a past
+                total_tokens = input_length + max_new_tokens - 1 + speculative_length
                 needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
                 request_blocks = [
                     b for b in range(num_blocks, num_blocks + needed_blocks)
@@ -247,15 +246,15 @@ class FlashCausalLMBatch(Batch):
                 request_slots = r.slots
 
             block_tables.append(request_blocks)
-            slots.extend(request_slots[:total_tokens])
             num_blocks += len(request_blocks)
-            start_slots.append(cumulative_max_length)
 
             request_slot_indices = torch.arange(
-                cumulative_max_length,
-                cumulative_max_length + input_length,
+                len(flat_slots),
+                len(flat_slots) + input_length,
                 dtype=torch.int64,
             )
+            slots.append(request_slots)
+            flat_slots.extend(request_slots)
             slot_indices.append(request_slot_indices)
 
             # Create tensor to slice into the kv tensor in prefill
@@ -289,7 +288,6 @@ class FlashCausalLMBatch(Batch):
 
             # Update
             cumulative_length += input_length
-            cumulative_max_length += total_tokens
             max_seqlen = max(max_seqlen, input_length)
             max_blocks = max(max_blocks, len(request_blocks))
             max_length = max(
@@ -299,7 +297,6 @@ class FlashCausalLMBatch(Batch):
         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
             next_token_chooser_parameters, dtype, device, tokenizer
         )
-        start_slots = torch.tensor(start_slots, dtype=torch.int64)
 
         # Padded all_input_ids_tensor
         all_input_ids_tensor = np.zeros(
@@ -356,7 +353,7 @@ class FlashCausalLMBatch(Batch):
             top_n_tokens, device=device, dtype=torch.int64
        )
 
-        slots = torch.tensor(slots, dtype=torch.int64, device=device)
+        slots_tensor = torch.tensor(flat_slots, dtype=torch.int64, device=device)
         block_tables_tensor = torch.zeros(
             (len(block_tables), max_blocks), dtype=torch.int32, device="cpu"
         )
@@ -372,11 +369,11 @@ class FlashCausalLMBatch(Batch):
             position_ids=position_ids,
             cu_seqlen_prefill=cu_seqlen_prefill,
             prefill_cache_indices=prefill_cache_indices,
-            start_slots=start_slots,
             slot_indices=slot_indices,
             block_tables=block_tables,
             block_tables_tensor=block_tables_tensor,
             slots=slots,
+            slots_tensor=slots_tensor,
             max_seqlen=max_seqlen,
             prefill_head_indices=prefill_head_indices,
             prefill_next_token_indices=prefill_next_token_indices,
@@ -423,18 +420,13 @@ class FlashCausalLMBatch(Batch):
         # Used to index into tensors
         indices = []
 
-        # slots to keep after filtering
-        slot_filtering_indices = torch.zeros(
-            self.slots.shape[0], dtype=torch.bool, device=device
-        )
-
-        # Create on CPU to only move to GPU once instead of at every copy
-        slot_indices = torch.empty(len(request_ids), dtype=torch.int64)
+        slot_indices = []
         max_seqlen = 0
 
         requests = []
-        start_slots = []
         block_tables = []
+        slots = []
+        flat_slots = []
         all_input_ids = []
 
         input_lengths = []
@@ -446,8 +438,6 @@ class FlashCausalLMBatch(Batch):
 
         num_blocks = 0
         max_blocks = 0
-        # Cumulative length
-        cumulative_max_length = 0
 
         for i, request_id in enumerate(request_ids):
             idx = self.requests_idx_mapping[request_id]
@@ -471,27 +461,17 @@ class FlashCausalLMBatch(Batch):
 
             top_n_tokens.append(self.top_n_tokens[idx])
 
-            remaining_tokens = (
-                stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
-            )
-
             request_block_table = self.block_tables[idx]
             num_blocks += len(request_block_table)
             block_tables.append(request_block_table)
-            start_slots.append(cumulative_max_length)
 
-            # Copy to tensor (CPU)
-            slot_indices[i] = cumulative_max_length + request_input_length - 1
+            # List of slots allocated for this request
+            request_slots = self.slots[idx]
+            slots.append(request_slots)
 
-            # Set slice
-            slot_filtering_indices[
-                self.start_slots[idx] : self.start_slots[idx]
-                + request_input_length
-                + remaining_tokens
-                - 1
-            ] = True
-
-            cumulative_max_length += request_input_length + remaining_tokens - 1
+            # Index
+            slot_indices.append(len(flat_slots) + request_input_length - 1)
+            flat_slots.extend(request_slots)
 
             max_blocks = max(max_blocks, len(request_block_table))
 
@@ -501,17 +481,15 @@ class FlashCausalLMBatch(Batch):
         all_input_ids_tensor = self.all_input_ids_tensor[indices]
         block_tables_tensor = self.block_tables_tensor[indices]
         input_lengths_tensor = self.input_lengths_tensor[indices]
-        slots = self.slots[slot_filtering_indices]
         next_token_chooser = self.next_token_chooser.filter(indices)
         top_n_tokens_tensor = self.top_n_tokens_tensor[indices]
         speculative_ids = (
             self.speculative_ids[indices] if self.speculative_ids is not None else None
         )
 
-        start_slots = torch.tensor(start_slots, dtype=torch.int64)
-
-        # Move to GPU now that we have the whole tensor
-        slot_indices = slot_indices.to(device)
+        # Allocate on GPU
+        slots_tensor = torch.tensor(flat_slots, dtype=torch.int64, device=device)
+        slot_indices = torch.tensor(slot_indices, dtype=torch.int64, device=device)
 
         return type(self)(
             batch_id=self.batch_id,
@@ -521,11 +499,11 @@ class FlashCausalLMBatch(Batch):
             position_ids=position_ids,
             cu_seqlen_prefill=None,
             prefill_cache_indices=None,
-            start_slots=start_slots,
             slot_indices=slot_indices,
             block_tables=block_tables,
             block_tables_tensor=block_tables_tensor,
             slots=slots,
+            slots_tensor=slots_tensor,
             max_seqlen=max_seqlen,
             prefill_head_indices=None,
             prefill_next_token_indices=None,
@@ -560,13 +538,14 @@ class FlashCausalLMBatch(Batch):
         max_seqlen = 0
         for b in batches:
             total_batch_size += len(b)
-            total_slots += len(b.slots)
+            total_slots += len(b.slots_tensor)
             num_blocks += b.num_blocks
             speculative_length = (
                 b.speculative_ids.shape[1] if b.speculative_ids is not None else 0
             )
             max_blocks = max(max_blocks, b.max_blocks)
             max_seqlen = max(max_seqlen, b.max_seqlen)
+            # When we filter, we do not recompute this value so we do so here
             max_length = max(
                 max_length,
                 max(
@@ -582,7 +561,7 @@ class FlashCausalLMBatch(Batch):
 
         input_ids = batches[0].input_ids.new_empty(total_batch_size)
         position_ids = batches[0].position_ids.new_empty(total_batch_size)
-        slots = batches[0].slots.new_empty(total_slots)
+        slots_tensor = batches[0].slots_tensor.new_empty(total_slots)
         slot_indices = batches[0].slot_indices.new_empty(total_batch_size)
         input_lengths_tensor = batches[0].input_lengths_tensor.new_empty(
             total_batch_size
@@ -597,7 +576,7 @@ class FlashCausalLMBatch(Batch):
             total_batch_size,
         )
 
-        start_slots = []
+        slots = []
         block_tables = []
         all_input_ids = []
 
@@ -627,7 +606,7 @@ class FlashCausalLMBatch(Batch):
             start_index = cumulative_batch_size
             end_index = cumulative_batch_size + len(batch)
             slots_start_index = cumulative_slots
-            slots_end_index = cumulative_slots + len(batch.slots)
+            slots_end_index = cumulative_slots + len(batch.slots_tensor)
 
             # Copy tensors (GPU)
             input_ids[start_index:end_index] = batch.input_ids
@@ -635,7 +614,7 @@ class FlashCausalLMBatch(Batch):
             slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots
             input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor
             top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
-            slots[slots_start_index:slots_end_index] = batch.slots
+            slots_tensor[slots_start_index:slots_end_index] = batch.slots_tensor
 
             all_input_ids_tensor[
                 start_index:end_index, : batch.all_input_ids_tensor.shape[1]
@@ -645,8 +624,7 @@ class FlashCausalLMBatch(Batch):
                 start_index:end_index, : batch.block_tables_tensor.shape[1]
             ] = batch.block_tables_tensor[:, :max_blocks]
 
-            start_slots.append(batch.start_slots + cumulative_slots)
-
+            slots.extend(batch.slots)
             block_tables.extend(batch.block_tables)
             all_input_ids.extend(batch.all_input_ids)
 
@@ -662,9 +640,7 @@ class FlashCausalLMBatch(Batch):
 
             # Update
             cumulative_batch_size += len(batch)
-            cumulative_slots += len(batch.slots)
-
-        start_slots = torch.concat(start_slots)
+            cumulative_slots += len(batch.slots_tensor)
 
         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
             next_token_chooser_parameters,
@@ -688,11 +664,11 @@ class FlashCausalLMBatch(Batch):
             position_ids=position_ids,
             cu_seqlen_prefill=None,
             prefill_cache_indices=None,
-            start_slots=start_slots,
             slot_indices=slot_indices,
             block_tables=block_tables,
             block_tables_tensor=block_tables_tensor,
             slots=slots,
+            slots_tensor=slots_tensor,
             max_seqlen=max_seqlen,
             prefill_head_indices=None,
             prefill_next_token_indices=None,
@@ -993,7 +969,7 @@ class FlashCausalLM(Model):
         cu_seqlen_prefill = batch.cu_seqlen_prefill
         kv_cache = self.kv_cache
         block_tables = batch.block_tables_tensor
-        slots = batch.slots[batch.slot_indices]
+        slots = batch.slots_tensor[batch.slot_indices]
         input_lengths = batch.input_lengths_tensor
         max_s = batch.max_seqlen
         lm_head_indices = batch.prefill_head_indices
@@ -1032,7 +1008,7 @@ class FlashCausalLM(Model):
         cu_seqlen_prefill = batch.cu_seqlen_prefill
         kv_cache = self.kv_cache
         block_tables = batch.block_tables_tensor
-        slots = batch.slots[batch.slot_indices]
+        slots = batch.slots_tensor[batch.slot_indices]
         input_lengths = batch.input_lengths_tensor
         max_s = batch.max_seqlen
         lm_head_indices = batch.prefill_head_indices
@@ -1374,6 +1350,7 @@ class FlashCausalLM(Model):
                 ),
                 generated_text,
                 top_tokens,
+                input_length + n_accepted_ids,
             )
 
             generations.append(generation)
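Note: with this change `FlashCausalLMBatch` keeps both a per-request `slots: List[List[int]]` and a flattened `slots_tensor`, which replaces the old `start_slots` offsets and the boolean `slot_filtering_indices` mask in `filter`/`concatenate`. The standalone sketch below (plain Python lists and hypothetical names instead of the real dataclass and torch tensors, not code from this patch) shows the same bookkeeping `filter` now does: rebuild the flat slot list and each request's `slot_indices` entry purely from the kept requests' slot lists.

```python
from typing import List, Tuple


def filter_slots(
    kept_indices: List[int],
    per_request_slots: List[List[int]],
    input_lengths: List[int],
) -> Tuple[List[List[int]], List[int], List[int]]:
    """Rebuild the flattened slot list and per-request slot indices for a filtered batch.

    per_request_slots[i] holds every paged-attention slot allocated to request i;
    input_lengths[i] is the number of tokens (prompt + generated) currently in request i.
    """
    kept_slots: List[List[int]] = []  # list of length b' of per-request slot lists
    flat_slots: List[int] = []        # concatenation of kept_slots
    slot_indices: List[int] = []      # index of each request's current slot in flat_slots

    for idx in kept_indices:
        request_slots = per_request_slots[idx]
        kept_slots.append(request_slots)
        # Slot of the last token of this request inside the flattened list
        slot_indices.append(len(flat_slots) + input_lengths[idx] - 1)
        flat_slots.extend(request_slots)

    return kept_slots, flat_slots, slot_indices


# Example: keep requests 0 and 2 out of three
slots = [[0, 1, 2, 3], [4, 5], [6, 7, 8]]
lengths = [2, 1, 3]
kept, flat, indices = filter_slots([0, 2], slots, lengths)
assert flat == [0, 1, 2, 3, 6, 7, 8]
assert indices == [1, 6]  # slot 1 for request 0, slot 8 (at flat index 6) for request 2
```

Because every request carries its own slot list, dropping a request is just dropping its list: no cumulative offsets or boolean masks have to be maintained, and `concatenate` can rebuild the flat tensor the same way with a running `cumulative_slots` offset.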
diff --git a/server/text_generation_server/models/idefics_causal_lm.py b/server/text_generation_server/models/idefics_causal_lm.py
index f507d669..44b21899 100644
--- a/server/text_generation_server/models/idefics_causal_lm.py
+++ b/server/text_generation_server/models/idefics_causal_lm.py
@@ -829,6 +829,7 @@ class IdeficsCausalLM(Model):
                 ),
                 generated_text,
                 top_tokens,
+                new_input_length,
             )
 
             generations.append(generation)
diff --git a/server/text_generation_server/models/mamba.py b/server/text_generation_server/models/mamba.py
index 3133a137..8182eb46 100644
--- a/server/text_generation_server/models/mamba.py
+++ b/server/text_generation_server/models/mamba.py
@@ -775,6 +775,7 @@ class Mamba(Model):
                 ),
                 generated_text,
                 top_tokens,
+                new_input_length,
             )
 
             generations.append(generation)
diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py
index 3bd09556..74ea2dab 100644
--- a/server/text_generation_server/models/seq2seq_lm.py
+++ b/server/text_generation_server/models/seq2seq_lm.py
@@ -801,6 +801,7 @@ class Seq2SeqLM(Model):
                 ),
                 generated_text,
                 top_tokens,
+                new_decoder_input_length,
             )
 
             generations.append(generation)
diff --git a/server/text_generation_server/models/types.py b/server/text_generation_server/models/types.py
index 339b733b..1c7a157a 100644
--- a/server/text_generation_server/models/types.py
+++ b/server/text_generation_server/models/types.py
@@ -84,6 +84,7 @@ class Generation:
     generated_text: Optional[GeneratedText]
     # Optional for now, since it's not yet supported for every model.
     top_tokens: Optional[List[Tokens]]
+    current_length: int
 
     def to_pb(self) -> generate_pb2.Generation:
         return generate_pb2.Generation(
@@ -100,4 +101,5 @@ class Generation:
                 if self.top_tokens is not None
                 else None
             ),
+            current_length=self.current_length,
         )
diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py
index 59a6fab1..b1ccd140 100644
--- a/server/text_generation_server/models/vlm_causal_lm.py
+++ b/server/text_generation_server/models/vlm_causal_lm.py
@@ -228,7 +228,7 @@ class VlmCausalLM(BaseFlashMistral):
         cu_seqlen_prefill = batch.cu_seqlen_prefill
         kv_cache = self.kv_cache
         block_tables = batch.block_tables_tensor
-        slots = batch.slots[batch.slot_indices]
+        slots = batch.slots_tensor[batch.slot_indices]
         input_lengths = batch.input_lengths_tensor
         max_s = batch.max_seqlen
         lm_head_indices = batch.prefill_head_indices
@@ -267,7 +267,7 @@ class VlmCausalLM(BaseFlashMistral):
         cu_seqlen_prefill = batch.cu_seqlen_prefill
         kv_cache = self.kv_cache
         block_tables = batch.block_tables_tensor
-        slots = batch.slots[batch.slot_indices]
+        slots = batch.slots_tensor[batch.slot_indices]
         input_lengths = batch.input_lengths_tensor
         max_s = batch.max_seqlen
         lm_head_indices = batch.prefill_head_indices
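Note: every model now reports `current_length` on each `Generation`: the non-flash models pass their new input length for the step, while `FlashCausalLM` passes `input_length + n_accepted_ids`, so speculative decoding can advance the count by more than one token per step. A minimal illustration (plain Python, hypothetical helper name, not code from this patch):

```python
def next_current_length(input_length: int, n_accepted_ids: int) -> int:
    """Length of a request after one decode step: prompt plus all tokens generated so far.

    n_accepted_ids is 1 without speculation, or 1 plus the number of accepted
    speculative tokens otherwise.
    """
    return input_length + n_accepted_ids


# A 10-token prompt whose first decode step accepts 1 regular and 2 speculative tokens:
assert next_current_length(10, 3) == 13
```

On the router side this value lands in `Entry.current_length`, which is what the commented-out `filter_update_allocations` compares against the allocation's `allocated_tokens` before deciding whether to extend it; the new `FinishReason::OutOfPages` suggests requests that cannot be extended would be terminated with `out_of_pages`.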