From 682db34b6ac2af2947e2ac68aa0e45a69d4e070f Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Mon, 26 Aug 2024 14:59:27 +0200 Subject: [PATCH] Handling debugger. --- backends/v3/src/backend.rs | 8 +-- backends/v3/src/block_allocator.rs | 94 ++++++++++++++++++++++++++---- backends/v3/src/queue.rs | 2 +- backends/v3/src/radix.rs | 76 ++++++++++++------------ flake.lock | 6 +- flake.nix | 2 +- launcher/src/main.rs | 58 ++++++++++++++---- router/src/lib.rs | 10 ++++ 8 files changed, 186 insertions(+), 70 deletions(-) diff --git a/backends/v3/src/backend.rs b/backends/v3/src/backend.rs index ec80d55a3..05a263705 100644 --- a/backends/v3/src/backend.rs +++ b/backends/v3/src/backend.rs @@ -43,13 +43,7 @@ impl BackendV3 { let attention: Attention = attention .parse() .unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`")); - let block_size = if attention == Attention::FlashDecoding { - 256 - } else if attention == Attention::FlashInfer { - 1 - } else { - 16 - }; + let block_size = attention.block_size(); let queue = Queue::new( requires_padding, diff --git a/backends/v3/src/block_allocator.rs b/backends/v3/src/block_allocator.rs index 1cbc58684..4fea172b6 100644 --- a/backends/v3/src/block_allocator.rs +++ b/backends/v3/src/block_allocator.rs @@ -91,7 +91,11 @@ async fn block_allocator_task( window_size: Option, mut receiver: mpsc::UnboundedReceiver, ) { - let mut allocator = RadixAllocator::new(block_size, blocks, window_size, prefix_caching); + let mut allocator: Box = if prefix_caching { + Box::new(RadixAllocator::new(block_size, blocks, window_size)) + } else { + Box::new(SimpleAllocator::new(blocks, block_size, window_size)) + }; while let Some(cmd) = receiver.recv().await { match cmd { BlockAllocatorCommand::Free { @@ -124,12 +128,82 @@ enum BlockAllocatorCommand { }, } -// pub trait Allocator { -// fn allocate( -// &mut self, -// tokens: u32, -// prefill_tokens: Option>>, -// ) -> Option; -// -// fn free(&mut self, blocks: Vec, allocation_id: u64); -// } +pub trait Allocator { + fn allocate( + &mut self, + tokens: u32, + prefill_tokens: Option>>, + ) -> Option; + + fn free(&mut self, blocks: Vec, allocation_id: u64); +} +pub struct SimpleAllocator { + free_blocks: Vec, + block_size: u32, + window_size: Option, +} + +impl SimpleAllocator { + fn new(blocks: u32, block_size: u32, window_size: Option) -> Self { + SimpleAllocator { + block_size, + // Block 0 is reserved for health checks + free_blocks: (1..blocks).collect(), + window_size, + } + } +} + +impl Allocator for SimpleAllocator { + fn allocate( + &mut self, + tokens: u32, + _prefill_tokens: Option>>, + ) -> Option { + // Apply window size + let (required_blocks, repeats) = { + let (tokens, repeats) = match self.window_size { + None => (tokens, 1), + Some(window_size) => { + let repeats = (tokens + window_size - 1) / window_size; + let tokens = core::cmp::min(tokens, window_size); + (tokens, repeats as usize) + } + }; + // Pad to a multiple of block size + let required_blocks = (tokens + self.block_size - 1) / self.block_size; + (required_blocks, repeats) + }; + + let tokens = tokens as usize; + if required_blocks > self.free_blocks.len() as u32 { + None + } else { + let blocks = self + .free_blocks + .split_off(self.free_blocks.len() - required_blocks as usize); + let mut slots = + Vec::with_capacity((required_blocks * self.block_size * repeats as u32) as usize); + + 'slots: for block_id in blocks.repeat(repeats).iter() { + for s in (block_id * self.block_size)..((block_id + 1) * self.block_size) { + slots.push(s); + if slots.len() == tokens { + break 'slots; + } + } + } + Some(BlockAllocation { + allocation_id: 0, + blocks, + slots, + prefix_len: 0, + block_allocator: None, + }) + } + } + + fn free(&mut self, blocks: Vec, _allocation_id: u64) { + self.free_blocks.extend(blocks) + } +} diff --git a/backends/v3/src/queue.rs b/backends/v3/src/queue.rs index faa57c113..4958b2d47 100644 --- a/backends/v3/src/queue.rs +++ b/backends/v3/src/queue.rs @@ -333,7 +333,7 @@ impl State { break 'entry_loop; } Some(block_allocation) => { - tracing::debug!("Allocation: {block_allocation:?}"); + tracing::info!("Allocation: {block_allocation:?}"); max_blocks = max(max_blocks, block_allocation.blocks.len() as u32); Some(block_allocation) } diff --git a/backends/v3/src/radix.rs b/backends/v3/src/radix.rs index ce79009a3..c6a9cf1e2 100644 --- a/backends/v3/src/radix.rs +++ b/backends/v3/src/radix.rs @@ -1,12 +1,10 @@ +use crate::block_allocator::{Allocator, BlockAllocation}; +use slotmap::{DefaultKey, SlotMap}; use std::{ collections::{BTreeSet, HashMap}, sync::Arc, }; -use slotmap::{DefaultKey, SlotMap}; - -use crate::block_allocator::BlockAllocation; - pub struct RadixAllocator { allocation_id: u64, @@ -21,25 +19,15 @@ pub struct RadixAllocator { // This isn't used because the prefix need to match without the windowing // mecanism. This at worst is overallocating, not necessarily being wrong. window_size: Option, - - /// Wether to actual use the radix tree for searching or not. - prefix_caching: bool, } impl RadixAllocator { - pub fn new( - block_size: u32, - n_blocks: u32, - window_size: Option, - prefix_caching: bool, - ) -> Self { - if prefix_caching { - assert_eq!( - block_size, 1, - "Radix tree allocator only works with block_size=1, was: {}", - block_size - ); - } + pub fn new(block_size: u32, n_blocks: u32, window_size: Option) -> Self { + assert_eq!( + block_size, 1, + "Radix tree allocator only works with block_size=1, was: {}", + block_size + ); // if window_size.is_some() { // unimplemented!("Window size not supported in the prefix-caching block allocator yet"); // } @@ -52,7 +40,6 @@ impl RadixAllocator { // Block 0 is reserved for health checks. free_blocks: (1..n_blocks).collect(), window_size, - prefix_caching, } } @@ -81,24 +68,23 @@ impl RadixAllocator { } // Allocator trait -impl RadixAllocator { - pub fn allocate( +impl Allocator for RadixAllocator { + fn allocate( &mut self, tokens: u32, prefill_tokens: Option>>, ) -> Option { let mut blocks = vec![]; - let prefix_node = match (self.prefix_caching, prefill_tokens.as_ref()) { - (true, Some(prefill_tokens)) => { - let node_id = self - .cache_blocks - .find(prefill_tokens.as_slice(), &mut blocks); - // Even if this allocation fails below, we need to increase he - // refcount to ensure that the prefix that was found is not evicted. + let prefix_node = if let Some(prefill_tokens) = prefill_tokens.as_ref() { + let node_id = self + .cache_blocks + .find(prefill_tokens.as_slice(), &mut blocks); + // Even if this allocation fails below, we need to increase he + // refcount to ensure that the prefix that was found is not evicted. - node_id - } - _ => self.cache_blocks.root_id(), + node_id + } else { + self.cache_blocks.root_id() }; self.cache_blocks @@ -108,7 +94,9 @@ impl RadixAllocator { let prefix_len = blocks.len(); let suffix_len = tokens - prefix_len as u32; - match self.alloc_or_reclaim(suffix_len as usize) { + let suffix_blocks = suffix_len; + + match self.alloc_or_reclaim(suffix_blocks as usize) { Some(suffix_blocks) => blocks.extend(suffix_blocks), None => { self.cache_blocks @@ -127,6 +115,8 @@ impl RadixAllocator { prefill_tokens: prefill_tokens.clone(), }; + tracing::info!("Blocks {blocks:?}"); + self.allocation_id += 1; self.allocations.insert(self.allocation_id, allocation); @@ -139,7 +129,7 @@ impl RadixAllocator { }) } - pub fn free(&mut self, blocks: Vec, allocation_id: u64) { + fn free(&mut self, blocks: Vec, allocation_id: u64) { let allocation = match self.allocations.remove(&allocation_id) { Some(allocation) => allocation, None => unreachable!("Tried to free an unknown allocation."), @@ -613,7 +603,21 @@ mod tests { cache.free(allocation.blocks.clone(), allocation.allocation_id); let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); - assert_eq!(allocation.blocks, vec![1, 2, 3, 8, 9, 10, 11, 7]); + assert_eq!(allocation.blocks, vec![4, 5, 6, 7, 8, 9, 10, 11]); + assert_eq!(allocation.prefix_len, 0); + } + + #[test] + fn allocator_block_size() { + let mut cache = RadixAllocator::new(256, 12, None, false); + let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); + assert_eq!(allocation.blocks, vec![11]); + assert_eq!(allocation.slots, allocation.slots); + assert_eq!(allocation.prefix_len, 0); + cache.free(allocation.blocks.clone(), allocation.allocation_id); + + let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); + assert_eq!(allocation.blocks, vec![11]); assert_eq!(allocation.prefix_len, 0); } diff --git a/flake.lock b/flake.lock index 1a6353f54..ad234ae40 100644 --- a/flake.lock +++ b/flake.lock @@ -835,11 +835,11 @@ ] }, "locked": { - "lastModified": 1724379657, - "narHash": "sha256-+CFDh1FUgyY7q0FiWhKJpHS7LlD3KbiqN5Z4Z+4bGmc=", + "lastModified": 1724638882, + "narHash": "sha256-ap2jIQi/FuUHR6HCht6ASWhoz8EiB99XmI8Esot38VE=", "owner": "oxalica", "repo": "rust-overlay", - "rev": "a18034322c7703fcfe5d7352a77981ba4a936a61", + "rev": "19b70f147b9c67a759e35824b241f1ed92e46694", "type": "github" }, "original": { diff --git a/flake.nix b/flake.nix index 87c6b5555..96e3e3e7f 100644 --- a/flake.nix +++ b/flake.nix @@ -57,7 +57,7 @@ { devShells = with pkgs; rec { - default = pure; + default = impure; pure = mkShell { buildInputs = [ diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 0d6662be7..cc1d518e3 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -8,7 +8,7 @@ use nix::unistd::Pid; use serde::Deserialize; use std::env; use std::ffi::OsString; -use std::io::{BufRead, BufReader, Lines}; +use std::io::{BufRead, BufReader}; use std::os::unix::process::{CommandExt, ExitStatusExt}; use std::path::Path; use std::process::{Child, Command, ExitStatus, Stdio}; @@ -18,7 +18,10 @@ use std::sync::{mpsc, Arc}; use std::thread; use std::thread::sleep; use std::time::{Duration, Instant}; -use std::{fs, io}; +use std::{ + fs, io, + io::{Read, Write}, +}; use thiserror::Error; use tracing_subscriber::{filter::LevelFilter, EnvFilter}; @@ -833,6 +836,7 @@ fn shard_manager( .args(shard_args) .env_clear() .envs(envs) + .stdin(Stdio::piped()) .stdout(Stdio::piped()) .stderr(Stdio::piped()) .process_group(0) @@ -854,12 +858,13 @@ fn shard_manager( }; // Redirect STDOUT to the console + let mut pstdin = p.stdin.take().unwrap(); let shard_stdout_reader = BufReader::new(p.stdout.take().unwrap()); let shard_stderr_reader = BufReader::new(p.stderr.take().unwrap()); //stdout tracing thread thread::spawn(move || { - log_lines(shard_stdout_reader.lines()); + log_lines(shard_stdout_reader); }); // We read stderr in another thread as it seems that lines() can block in some cases let (err_sender, err_receiver) = mpsc::channel(); @@ -868,6 +873,18 @@ fn shard_manager( err_sender.send(line).unwrap_or(()); } }); + // We read stdin in another thread as it seems that lines() can block in some cases + thread::spawn(move || { + let mut stdin = io::stdin(); // We get `Stdin` here. + loop { + let mut buffer = vec![0; 4096]; + if let Ok(n) = stdin.read(&mut buffer) { + if n > 0 { + let _ = pstdin.write_all(&buffer[..n]); + } + } + } + }); let mut ready = false; let start_time = Instant::now(); @@ -974,19 +991,36 @@ impl PythonLogMessage { } } -impl TryFrom<&String> for PythonLogMessage { +impl TryFrom<&[u8]> for PythonLogMessage { type Error = serde_json::Error; - fn try_from(value: &String) -> Result { - serde_json::from_str::(value) + fn try_from(value: &[u8]) -> Result { + serde_json::from_slice::(value) } } -fn log_lines(lines: Lines) { - for line in lines.map_while(Result::ok) { - match PythonLogMessage::try_from(&line) { - Ok(log) => log.trace(), - Err(_) => tracing::debug!("{line}"), +fn log_lines(mut bufread: BufReader) { + let mut buffer = vec![0u8; 4096]; + let mut stdout = std::io::stdout(); + loop { + let n = bufread.read(&mut buffer); + if let Ok(n) = n { + if n > 0 { + let mut lines = buffer[..n].split(|i| *i == b'\n').peekable(); + while let Some(line) = lines.next() { + match PythonLogMessage::try_from(line) { + Ok(log) => log.trace(), + // For interactive debugging ? + Err(_) => { + stdout.write_all(line).unwrap(); + if lines.peek().is_some() { + stdout.write_all(b"\n").unwrap(); + } + stdout.flush().unwrap(); + } + } + } + } } } } @@ -1146,7 +1180,7 @@ fn download_convert_model( let download_stdout = BufReader::new(download_process.stdout.take().unwrap()); thread::spawn(move || { - log_lines(download_stdout.lines()); + log_lines(download_stdout); }); let download_stderr = BufReader::new(download_process.stderr.take().unwrap()); diff --git a/router/src/lib.rs b/router/src/lib.rs index ce4f7c467..5a9779d55 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -22,6 +22,16 @@ pub enum Attention { FlashInfer, } +impl Attention { + pub fn block_size(&self) -> u32 { + match self { + Attention::FlashDecoding => 256, + Attention::FlashInfer => 1, + Attention::Paged => 16, + } + } +} + #[derive(Debug)] pub struct ParseError;