From 9ac7b7bc521495b3b6335240d9b3311c79a47c7f Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Wed, 12 Jun 2024 11:50:31 +0200
Subject: [PATCH] remove slots from grpc

---
 benchmark/src/generation.rs            |  1 -
 proto/v3/generate.proto                |  4 --
 router/client/src/v3/client.rs         |  1 -
 router/client/src/v3/sharded_client.rs |  1 -
 router/src/infer/v3/block_allocator.rs | 44 +++++--------------
 router/src/infer/v3/queue.rs           | 18 ++++----
 router/src/infer/v3/scheduler.rs       | 15 +++----
 .../models/flash_causal_lm.py           | 43 ++++++++++--------
 8 files changed, 52 insertions(+), 75 deletions(-)

diff --git a/benchmark/src/generation.rs b/benchmark/src/generation.rs
index b82d23ba..e5fbdca4 100644
--- a/benchmark/src/generation.rs
+++ b/benchmark/src/generation.rs
@@ -156,7 +156,6 @@ async fn prefill(
             }),
             top_n_tokens: top_n_tokens.unwrap_or(0),
             blocks: vec![],
-            slots: vec![],
         })
         .collect();
 
diff --git a/proto/v3/generate.proto b/proto/v3/generate.proto
index 8138e4fb..c6e02034 100644
--- a/proto/v3/generate.proto
+++ b/proto/v3/generate.proto
@@ -132,8 +132,6 @@ message Request {
     uint32 top_n_tokens = 7;
     /// Paged attention blocks
     repeated uint32 blocks = 9;
-    /// Paged attention slots
-    repeated uint32 slots = 10;
 }
 
 message Batch {
@@ -208,8 +206,6 @@ message KeptRequest {
     uint64 id = 1;
     /// Paged attention blocks
     repeated uint32 blocks = 2;
-    /// Paged attention slots
-    repeated uint32 slots = 3;
 }
 
 /// kept_requests + terminated_request_ids might not cover all requests from the
diff --git a/router/client/src/v3/client.rs b/router/client/src/v3/client.rs
index 1f8070ca..03efd4f5 100644
--- a/router/client/src/v3/client.rs
+++ b/router/client/src/v3/client.rs
@@ -157,7 +157,6 @@ impl Client {
                 truncate,
                 // Blocks and slots will be set on the server side if we use paged attention
                 blocks: vec![],
-                slots: vec![],
                 // Set sampling parameters to also take these ops into account in the max memory
                 parameters: Some(NextTokenChooserParameters {
                     temperature: 0.9,
diff --git a/router/client/src/v3/sharded_client.rs b/router/client/src/v3/sharded_client.rs
index 3f11e101..1f9ec3ad 100644
--- a/router/client/src/v3/sharded_client.rs
+++ b/router/client/src/v3/sharded_client.rs
@@ -250,7 +250,6 @@ impl Health for ShardedClient {
             top_n_tokens: 0,
             // Block 0 is reserved for health checks
             blocks: vec![0],
-            slots: (0..16).collect(),
         };
         let batch = Batch {
             id: u64::MAX,
diff --git a/router/src/infer/v3/block_allocator.rs b/router/src/infer/v3/block_allocator.rs
index 18480dbb..e19450e3 100644
--- a/router/src/infer/v3/block_allocator.rs
+++ b/router/src/infer/v3/block_allocator.rs
@@ -1,11 +1,12 @@
+use std::cmp::min;
 use std::fmt::Formatter;
 use std::sync::{Arc, Mutex, TryLockError};
 use thiserror::Error;
 
 #[derive(Clone)]
 pub(crate) struct BlockAllocation {
+    block_size: usize,
     allocated_blocks: Vec<u32>,
-    allocated_slots: Vec<u32>,
     required_blocks: usize,
     required_slots: usize,
     block_allocator: BlockAllocator,
@@ -13,25 +14,20 @@ pub(crate) struct BlockAllocation {
 
 impl BlockAllocation {
     pub(crate) fn len(&self) -> usize {
-        self.allocated_slots.len()
+        self.allocated_blocks.len() * self.block_size
     }
 
     pub(crate) fn blocks(&self) -> &[u32] {
         &self.allocated_blocks
     }
 
-    pub(crate) fn slots(&self) -> &[u32] {
-        &self.allocated_slots
-    }
-
     /// Extend an allocation by adding a new block
     /// If the allocation length > window size, repeats blocks and slots to cover the
     /// whole `required_blocks` and `required_slots`
     pub(crate) fn extend(&mut self) -> Result<(), AllocationError> {
-        let (block, slots) = self.block_allocator.allocate_block()?;
+        let block = self.block_allocator.allocate_block()?;
         // Add block and slots to current allocation
         self.allocated_blocks.push(block);
-        self.allocated_slots.extend(slots);
 
         if let Some(window_size) = self.block_allocator.window_size {
             // if we have more slots than the window size,
@@ -41,8 +37,6 @@ impl BlockAllocation {
             let repeats = (self.required_slots + window_size - 1) / window_size;
             self.allocated_blocks = self.allocated_blocks.repeat(repeats);
             self.allocated_blocks.truncate(self.required_blocks);
-            self.allocated_slots = self.allocated_slots.repeat(repeats);
-            self.allocated_slots.truncate(self.required_slots);
         }
     }
 
@@ -62,7 +56,6 @@ impl std::fmt::Debug for BlockAllocation {
     fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
         f.debug_struct("BlockAllocation")
             .field("allocated_blocks", &self.allocated_blocks.len())
-            .field("allocated_slots", &self.allocated_slots.len())
             .field("required_blocks", &self.required_blocks)
             .field("required_slots", &self.required_slots)
             .field("block_allocator", &self.block_allocator)
@@ -94,30 +87,29 @@ impl BlockAllocator {
         }
     }
 
-    fn allocate_block(&self) -> Result<(u32, Vec<u32>), AllocationError> {
+    fn allocate_block(&self) -> Result<u32, AllocationError> {
         let mut free_blocks = self.free_blocks.lock().expect("Lock could not be acquired");
         if free_blocks.is_empty() {
             return Err(AllocationError::NotEnoughPages);
         }
-        let block_id = free_blocks.pop().unwrap();
-        let slots = ((block_id * self.block_size)..((block_id + 1) * self.block_size)).collect();
-        Ok((block_id, slots))
+        Ok(free_blocks.pop().unwrap())
     }
 
     /// For prompt tokens, we allocate enough blocks to cover all tokens
-    /// For decode tokens, we allocate block by block
+    /// For decode tokens, we allocate min(decode_blocks, 16) blocks
     ///
-    /// If prompt tokens + min(decode_tokens, block_size) > window size, we repeat blocks and slots
+    /// If allocation > window size, we repeat blocks and slots
     pub(crate) fn block_allocation(
        &self,
         prompt_tokens: u32,
         decode_tokens: u32,
     ) -> Result<BlockAllocation, AllocationError> {
         let required_prompt_blocks = (prompt_tokens + self.block_size - 1) / self.block_size;
-        // prompt blocks + a single block for decode
-        let required_blocks = required_prompt_blocks + 1;
+        // prompt blocks + 16 blocks for decode
+        let decode_blocks = (decode_tokens + self.block_size - 1) / self.block_size;
+        let required_blocks = required_prompt_blocks + min(decode_blocks, 16);
         let required_slots = required_blocks * self.block_size;
 
         // Slots and blocks required for the whole request
@@ -164,21 +156,9 @@ impl BlockAllocator {
             allocated_blocks
         };
 
-        let mut allocated_slots =
-            Vec::with_capacity(allocated_blocks.len() * self.block_size as usize * repeats);
-
-        'slots: for block_id in allocated_blocks.iter() {
-            for s in (block_id * self.block_size)..((block_id + 1) * self.block_size) {
-                allocated_slots.push(s);
-                if allocated_slots.len() > total_slots {
-                    break 'slots;
-                }
-            }
-        }
-
         Ok(BlockAllocation {
+            block_size: self.block_size as usize,
             allocated_blocks,
-            allocated_slots,
             required_blocks: total_required_blocks,
             required_slots: total_slots,
             block_allocator: self.clone(),
diff --git a/router/src/infer/v3/queue.rs b/router/src/infer/v3/queue.rs
index d8085800..43d2bdd8 100644
--- a/router/src/infer/v3/queue.rs
+++ b/router/src/infer/v3/queue.rs
@@ -224,6 +224,11 @@ impl State {
             }
         }
 
+        // Check if max_size == 0
+        if max_size == Some(0) {
+            return None;
+        }
+
         // Pad prefill_token_budget to be a multiple of block size
         let prefill_token_budget = ((prefill_token_budget + self.block_size - 1)
             / self.block_size) * self.block_size;
@@ -312,14 +317,10 @@ impl State {
             // Update entry
             entry.temp_span = Some(entry_batch_span);
 
-            let (blocks, slots) = match &block_allocation {
-                None => (Vec::new(), Vec::new()),
-                Some(block_allocation) => (
-                    block_allocation.blocks().to_vec(),
-                    block_allocation.slots().to_vec(),
-                ),
-            };
-
+            let blocks = block_allocation
+                .as_ref()
+                .map(|block_allocation| block_allocation.blocks().to_vec())
+                .unwrap_or_default();
             entry.block_allocation = block_allocation;
 
             batch_requests.push(Request {
@@ -338,7 +339,6 @@ impl State {
                 )),
                 top_n_tokens: entry.request.top_n_tokens,
                 blocks,
-                slots,
             });
             // Set batch_time
             entry.batch_time = Some(Instant::now());
diff --git a/router/src/infer/v3/scheduler.rs b/router/src/infer/v3/scheduler.rs
index 6e5ffa7e..50b33951 100644
--- a/router/src/infer/v3/scheduler.rs
+++ b/router/src/infer/v3/scheduler.rs
@@ -164,7 +164,7 @@ pub(crate) async fn batching_task(
             };
 
             let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
-            let max_size = max_batch_size.map(|max_size| max_size - batch_size as usize);
+            let max_size = max_batch_size.map(|max_size| max_size.saturating_sub(batch_size as usize));
 
             // Try to get a new batch
             if let Some((mut new_entries, new_batch, span)) = queue
@@ -382,16 +382,15 @@ async fn filter_batch(
     let updated_requests = entries
         .iter()
         .map(|(request_id, entry)| {
-            let (blocks, slots) = entry
+            let blocks = entry
                 .block_allocation
                 .as_ref()
-                .map(|alloc| (alloc.blocks().to_vec(), alloc.slots().to_vec()))
+                .map(|alloc| alloc.blocks().to_vec())
                 .unwrap_or_default();
 
             KeptRequest {
                 id: *request_id,
                 blocks,
-                slots,
             }
         })
         .collect();
@@ -991,10 +990,10 @@ mod tests {
             content: "You are a friendly chatbot who always responds in the style of a pirate"
                 .to_string(),
         }]
-            .iter()
-            .chain(&example_chat)
-            .cloned()
-            .collect::<Vec<_>>();
+        .iter()
+        .chain(&example_chat)
+        .cloned()
+        .collect::<Vec<_>>();
 
         let test_default_templates = vec![
             ChatTemplateTestItem {
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index e8fd8b16..1bf9b7a5 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -133,9 +133,12 @@ class FlashCausalLMBatch(Batch):
             batch_inputs.append(concat_text_chunks(r.input_chunks.chunks))
             max_truncation = max(max_truncation, r.truncate)
 
+        logger.error(batch_inputs)
+
         batch_tokenized_inputs = tokenizer(
             batch_inputs, truncation=True, max_length=max_truncation
         )["input_ids"]
+        logger.error(batch_tokenized_inputs)
         return batch_tokenized_inputs
 
     @classmethod
@@ -179,7 +182,7 @@ class FlashCausalLMBatch(Batch):
         max_blocks = 0
 
         block_tables = []
-        flat_slots = []
+        flat_blocks = []
 
         # Parse batch
         for i, (r, tokenized_input) in enumerate(
@@ -231,24 +234,18 @@ class FlashCausalLMBatch(Batch):
                 request_blocks = [
                     b for b in range(num_blocks, num_blocks + needed_blocks)
                 ]
-                request_slots = [
-                    s
-                    for b in request_blocks
-                    for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)
-                ]
             else:
                 request_blocks = r.blocks
-                request_slots = r.slots
 
             block_tables.append(request_blocks)
             num_blocks += len(request_blocks)
 
             request_slot_indices = torch.arange(
-                len(flat_slots),
-                len(flat_slots) + input_length,
+                len(flat_blocks) * BLOCK_SIZE,
+                (len(flat_blocks) * BLOCK_SIZE) + input_length,
                 dtype=torch.int64,
             )
-            flat_slots.extend(request_slots)
+            flat_blocks.extend(request_blocks)
             slot_indices.append(request_slot_indices)
 
             # Create tensor to slice into the kv tensor in prefill
@@ -347,7 +344,13 @@ class FlashCausalLMBatch(Batch):
             top_n_tokens, device=device, dtype=torch.int64
         )
 
-        slots = torch.tensor(flat_slots, dtype=torch.int64, device=device)
+        flat_blocks_tensor = torch.tensor(flat_blocks, dtype=torch.int64, device=device)
+
+        slots = (
+            (flat_blocks_tensor * BLOCK_SIZE).repeat(BLOCK_SIZE, 1).T
+            + torch.arange(0, BLOCK_SIZE, device=device, dtype=torch.int64)
+        ).flatten()
+
         block_tables_tensor = torch.zeros(
             (len(block_tables), max_blocks), dtype=torch.int32, device="cpu"
         )
@@ -444,8 +447,8 @@ class FlashCausalLMBatch(Batch):
         max_seqlen = 0
 
         requests = []
+        flat_blocks = []
         block_tables = []
-        flat_slots = []
         all_input_ids = []
 
         input_lengths = []
@@ -483,16 +486,13 @@ class FlashCausalLMBatch(Batch):
             top_n_tokens.append(self.top_n_tokens[idx])
 
             request_block_table = request.blocks
-            num_blocks += len(request_block_table)
             block_tables.append(request_block_table)
-
-            # List of slots allocated for this request
-            request_slots = request.slots
+            flat_blocks.extend(request_block_table)
 
             # Index
-            slot_indices.append(len(flat_slots) + request_input_length - 1)
-            flat_slots.extend(request_slots)
+            slot_indices.append((num_blocks * BLOCK_SIZE) + request_input_length - 1)
+            num_blocks += len(request_block_table)
 
             max_blocks = max(max_blocks, len(request_block_table))
 
             # Index into tensors
@@ -514,11 +514,16 @@ class FlashCausalLMBatch(Batch):
             block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks)
 
         # Allocate on GPU
-        slots = torch.tensor(flat_slots, dtype=torch.int64, device=device)
         slot_indices = torch.tensor(slot_indices, dtype=torch.int64, device=device)
 
         # Move to GPU
         block_tables_tensor = block_tables_tensor.to(device)
+        flat_blocks_tensor = torch.tensor(flat_blocks, dtype=torch.int64, device=device)
+
+        slots = (
+            (flat_blocks_tensor * BLOCK_SIZE).repeat(BLOCK_SIZE, 1).T
+            + torch.arange(0, BLOCK_SIZE, device=device, dtype=torch.int64)
+        ).flatten()
 
         filtered_batch = type(self)(
             batch_id=self.batch_id,
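
Not part of the patch above: a minimal sketch, assuming BLOCK_SIZE = 16 and an arbitrary toy list of block ids, of why the slots field can be dropped from the gRPC protocol. The flat slot list the router used to transmit is fully determined by the block ids, and the tensor expression added to flash_causal_lm.py rebuilds it on the server, so only blocks need to cross the wire (the slots list is BLOCK_SIZE times longer than the blocks list).

# Illustration only, not part of the commit. Names mirror flash_causal_lm.py,
# but the snippet is standalone and the values are assumptions.
import torch

BLOCK_SIZE = 16
flat_blocks = [0, 3, 7]  # block ids as they now arrive over gRPC (no slots)

# What the router previously sent: every block id expanded into its slot range.
expected_slots = [
    s for b in flat_blocks for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)
]

# What the server now computes from the block ids alone (same expression as the patch).
flat_blocks_tensor = torch.tensor(flat_blocks, dtype=torch.int64)
slots = (
    (flat_blocks_tensor * BLOCK_SIZE).repeat(BLOCK_SIZE, 1).T
    + torch.arange(0, BLOCK_SIZE, dtype=torch.int64)
).flatten()

assert slots.tolist() == expected_slots

The same arithmetic is what lets the filter path compute a request's next slot index as num_blocks * BLOCK_SIZE + request_input_length - 1 instead of indexing into a transmitted slot list.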