From 45a86d5cf0cc93f3f16082ff163d2cbfe6a55aa3 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Tue, 10 Dec 2024 17:57:17 +0100
Subject: [PATCH] Simple attempt to fix the healthcheck block allocation.

---
 backends/v3/src/block_allocator.rs       |  4 +--
 backends/v3/src/client/mod.rs            |  2 ++
 backends/v3/src/client/sharded_client.rs | 40 +++++++++++++++++++++---
 backends/v3/src/radix.rs                 |  4 +--
 4 files changed, 41 insertions(+), 9 deletions(-)

diff --git a/backends/v3/src/block_allocator.rs b/backends/v3/src/block_allocator.rs
index 4fea172b..d54bb0ee 100644
--- a/backends/v3/src/block_allocator.rs
+++ b/backends/v3/src/block_allocator.rs
@@ -147,8 +147,8 @@ impl SimpleAllocator {
     fn new(blocks: u32, block_size: u32, window_size: Option<u32>) -> Self {
         SimpleAllocator {
             block_size,
-            // Block 0 is reserved for health checks
-            free_blocks: (1..blocks).collect(),
+            // XXX: Block 0&1 is reserved for health checks
+            free_blocks: (2..blocks).collect(),
             window_size,
         }
     }
diff --git a/backends/v3/src/client/mod.rs b/backends/v3/src/client/mod.rs
index d4ac50c9..9c72be64 100644
--- a/backends/v3/src/client/mod.rs
+++ b/backends/v3/src/client/mod.rs
@@ -37,6 +37,8 @@ pub enum ClientError {
     Generation(String),
     #[error("Sharded results are empty")]
     EmptyResults,
+    #[error("Invalid attention {0}")]
+    InvalidAttention(String),
 }
 
 impl From<Status> for ClientError {
diff --git a/backends/v3/src/client/sharded_client.rs b/backends/v3/src/client/sharded_client.rs
index 6d4e207b..acec4dc0 100644
--- a/backends/v3/src/client/sharded_client.rs
+++ b/backends/v3/src/client/sharded_client.rs
@@ -13,15 +13,38 @@ use futures::future::join_all;
 use tonic::transport::Uri;
 use tracing::instrument;
 
+#[derive(Debug, Clone, Copy)]
+pub enum Attn {
+    Flashdecoding,
+    Flashinfer,
+    Paged,
+}
+
+impl TryFrom<&str> for Attn {
+    type Error = ClientError;
+    fn try_from(value: &str) -> std::result::Result<Self, Self::Error> {
+        match value {
+            "flashdecoding" => Ok(Attn::Flashdecoding),
+            "flashinfer" => Ok(Attn::Flashinfer),
+            "paged" => Ok(Attn::Paged),
+            string => Err(ClientError::InvalidAttention(string.to_string())),
+        }
+    }
+}
+
 #[derive(Debug, Clone)]
 /// Text Generation Inference gRPC multi client
 pub struct ShardedClient {
     clients: Vec<Client>,
+    attention_impl: Option<Attn>,
 }
 
 impl ShardedClient {
     fn new(clients: Vec<Client>) -> Self {
-        Self { clients }
+        Self {
+            clients,
+            attention_impl: None,
+        }
     }
 
     /// Create a new ShardedClient from a master client. The master client will communicate with
@@ -55,7 +78,9 @@ impl ShardedClient {
             .iter_mut()
             .map(|client| client.info())
             .collect();
-        join_all(futures).await.pop().unwrap()
+        let info = join_all(futures).await.pop().unwrap()?;
+        self.attention_impl = Some((&*info.attention_impl).try_into()?);
+        Ok(info)
     }
 
     /// GRPC health check
@@ -211,6 +236,12 @@ impl Health for ShardedClient {
 
     async fn model_health(&self) -> Result<()> {
         // Dummy batch of 1 token and 1 generated token
+        //
+        let (blocks, slots) = match self.attention_impl.expect("Attention to be set") {
+            Attn::Paged => (vec![0], (0..2).collect()),
+            Attn::Flashinfer => (vec![0, 1], (0..2).collect()),
+            Attn::Flashdecoding => (vec![0], (0..2).collect()),
+        };
         let liveness_request = Request {
             id: u64::MAX,
             inputs: "liveness".to_string(),
@@ -239,9 +270,8 @@ impl Health for ShardedClient {
                 ignore_eos_token: false,
             }),
             top_n_tokens: 0,
-            // Block 0 is reserved for health checks
-            blocks: vec![0],
-            slots: (0..16).collect(),
+            blocks,
+            slots,
             cache_len: 0,
             adapter_id: None,
             chunk_len: None,
diff --git a/backends/v3/src/radix.rs b/backends/v3/src/radix.rs
index 8a544891..cfa32002 100644
--- a/backends/v3/src/radix.rs
+++ b/backends/v3/src/radix.rs
@@ -42,8 +42,8 @@ impl RadixAllocator {
             allocations: HashMap::new(),
             cache_blocks: RadixTrie::new(block_size as usize),
 
-            // Block 0 is reserved for health checks.
-            free_blocks: (1..n_blocks).collect(),
+            // XXX: Block 0,1 is reserved for health checks.
+            free_blocks: (2..n_blocks).collect(),
             window_size,
             block_size,
         }