From e47249a3ea1189c77051f7336814ca93d2a87413 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 10 Dec 2024 18:55:21 +0100 Subject: [PATCH] Much simpler solution. --- backends/v2/src/backend.rs | 4 +++ backends/v3/src/backend.rs | 4 +++ backends/v3/src/block_allocator.rs | 4 +-- backends/v3/src/client/mod.rs | 2 -- backends/v3/src/client/sharded_client.rs | 44 ++++-------------------- backends/v3/src/radix.rs | 4 +-- router/src/infer/mod.rs | 7 +++- 7 files changed, 25 insertions(+), 44 deletions(-) diff --git a/backends/v2/src/backend.rs b/backends/v2/src/backend.rs index bc264138..cfe87f98 100644 --- a/backends/v2/src/backend.rs +++ b/backends/v2/src/backend.rs @@ -104,6 +104,10 @@ impl Backend for BackendV2 { } .is_ok() } + + fn start_health(&self) -> bool { + true + } } /// Batching logic diff --git a/backends/v3/src/backend.rs b/backends/v3/src/backend.rs index 7ae794a0..736301b3 100644 --- a/backends/v3/src/backend.rs +++ b/backends/v3/src/backend.rs @@ -111,6 +111,10 @@ impl Backend for BackendV3 { } .is_ok() } + + fn start_health(&self) -> bool { + true + } } /// Batching logic diff --git a/backends/v3/src/block_allocator.rs b/backends/v3/src/block_allocator.rs index d54bb0ee..4fea172b 100644 --- a/backends/v3/src/block_allocator.rs +++ b/backends/v3/src/block_allocator.rs @@ -147,8 +147,8 @@ impl SimpleAllocator { fn new(blocks: u32, block_size: u32, window_size: Option) -> Self { SimpleAllocator { block_size, - // XXX: Block 0&1 is reserved for health checks - free_blocks: (2..blocks).collect(), + // Block 0 is reserved for health checks + free_blocks: (1..blocks).collect(), window_size, } } diff --git a/backends/v3/src/client/mod.rs b/backends/v3/src/client/mod.rs index 9c72be64..d4ac50c9 100644 --- a/backends/v3/src/client/mod.rs +++ b/backends/v3/src/client/mod.rs @@ -37,8 +37,6 @@ pub enum ClientError { Generation(String), #[error("Sharded results are empty")] EmptyResults, - #[error("Invalid attention {0}")] - InvalidAttention(String), } impl From for ClientError { diff --git a/backends/v3/src/client/sharded_client.rs b/backends/v3/src/client/sharded_client.rs index acec4dc0..4701c560 100644 --- a/backends/v3/src/client/sharded_client.rs +++ b/backends/v3/src/client/sharded_client.rs @@ -13,38 +13,15 @@ use futures::future::join_all; use tonic::transport::Uri; use tracing::instrument; -#[derive(Debug, Clone, Copy)] -pub enum Attn { - Flashdecoding, - Flashinfer, - Paged, -} - -impl TryFrom<&str> for Attn { - type Error = ClientError; - fn try_from(value: &str) -> std::result::Result { - match value { - "flashdecoding" => Ok(Attn::Flashdecoding), - "flashinfer" => Ok(Attn::Flashinfer), - "paged" => Ok(Attn::Paged), - string => Err(ClientError::InvalidAttention(string.to_string())), - } - } -} - #[derive(Debug, Clone)] /// Text Generation Inference gRPC multi client pub struct ShardedClient { clients: Vec, - attention_impl: Option, } impl ShardedClient { fn new(clients: Vec) -> Self { - Self { - clients, - attention_impl: None, - } + Self { clients } } /// Create a new ShardedClient from a master client. The master client will communicate with @@ -78,9 +55,7 @@ impl ShardedClient { .iter_mut() .map(|client| client.info()) .collect(); - let info = join_all(futures).await.pop().unwrap()?; - self.attention_impl = Some((&*info.attention_impl).try_into()?); - Ok(info) + join_all(futures).await.pop().unwrap() } /// GRPC health check @@ -236,20 +211,14 @@ impl Health for ShardedClient { async fn model_health(&self) -> Result<()> { // Dummy batch of 1 token and 1 generated token - // - let (blocks, slots) = match self.attention_impl.expect("Attention to be set") { - Attn::Paged => (vec![0], (0..2).collect()), - Attn::Flashinfer => (vec![0, 1], (0..2).collect()), - Attn::Flashdecoding => (vec![0], (0..2).collect()), - }; let liveness_request = Request { id: u64::MAX, inputs: "liveness".to_string(), input_chunks: Some(Input { chunks: vec![Chunk::Text("liveness".into()).into()], }), - truncate: 10, - add_special_tokens: true, + truncate: 1, + add_special_tokens: false, prefill_logprobs: false, parameters: Some(NextTokenChooserParameters { temperature: 1.0, @@ -270,8 +239,9 @@ impl Health for ShardedClient { ignore_eos_token: false, }), top_n_tokens: 0, - blocks, - slots, + // Block 0 is reserved for health checks + blocks: vec![0], + slots: vec![0], cache_len: 0, adapter_id: None, chunk_len: None, diff --git a/backends/v3/src/radix.rs b/backends/v3/src/radix.rs index cfa32002..8a544891 100644 --- a/backends/v3/src/radix.rs +++ b/backends/v3/src/radix.rs @@ -42,8 +42,8 @@ impl RadixAllocator { allocations: HashMap::new(), cache_blocks: RadixTrie::new(block_size as usize), - // XXX: Block 0,1 is reserved for health checks. - free_blocks: (2..n_blocks).collect(), + // Block 0 is reserved for health checks. + free_blocks: (1..n_blocks).collect(), window_size, block_size, } diff --git a/router/src/infer/mod.rs b/router/src/infer/mod.rs index 86606643..0638d0c6 100644 --- a/router/src/infer/mod.rs +++ b/router/src/infer/mod.rs @@ -33,6 +33,11 @@ pub trait Backend { ) -> Result>, InferError>; async fn health(&self, current_health: bool) -> bool; + + /// The state of the health on startup + /// Typically false, or true if the backend includes + /// a warmup phase. + fn start_health(&self) -> bool; } /// Inference struct @@ -75,7 +80,7 @@ impl Infer { let semaphore = Arc::new(Semaphore::new(max_concurrent_requests)); // Backend health - let backend_health = Arc::new(AtomicBool::new(false)); + let backend_health = Arc::new(AtomicBool::new(backend.start_health())); Self { validation,