Much simpler solution.

Author: Nicolas Patry, 2024-12-10 18:55:21 +01:00
parent 45a86d5cf0
commit e47249a3ea
7 changed files with 25 additions and 44 deletions

View File

@@ -104,6 +104,10 @@ impl Backend for BackendV2 {
         }
         .is_ok()
     }
+
+    fn start_health(&self) -> bool {
+        true
+    }
 }
 
 /// Batching logic

View File

@@ -111,6 +111,10 @@ impl Backend for BackendV3 {
         }
         .is_ok()
     }
+
+    fn start_health(&self) -> bool {
+        true
+    }
 }
 
 /// Batching logic

View File

@@ -147,8 +147,8 @@ impl SimpleAllocator {
     fn new(blocks: u32, block_size: u32, window_size: Option<u32>) -> Self {
         SimpleAllocator {
             block_size,
-            // XXX: Block 0&1 is reserved for health checks
-            free_blocks: (2..blocks).collect(),
+            // Block 0 is reserved for health checks
+            free_blocks: (1..blocks).collect(),
             window_size,
         }
     }
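This allocator change is the heart of the simplification: only block 0 needs to be reserved, because the health check (see the ShardedClient hunk below) now always pins its dummy batch to block 0 and slot 0. A minimal standalone sketch of the invariant, using a hypothetical free_list helper rather than the crate's API:

// Hypothetical helper mirroring the `(1..blocks).collect()` change above:
// block 0 is excluded from the free list, so it can never be handed to a
// real request and stays available for the health-check batch.
fn free_list(blocks: u32) -> Vec<u32> {
    (1..blocks).collect()
}

fn main() {
    let free = free_list(8);
    assert!(!free.contains(&0)); // block 0 stays reserved
    assert_eq!(free, vec![1, 2, 3, 4, 5, 6, 7]);
}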

View File

@@ -37,8 +37,6 @@ pub enum ClientError {
     Generation(String),
     #[error("Sharded results are empty")]
     EmptyResults,
-    #[error("Invalid attention {0}")]
-    InvalidAttention(String),
 }
 
 impl From<Status> for ClientError {

View File

@@ -13,38 +13,15 @@ use futures::future::join_all;
 use tonic::transport::Uri;
 use tracing::instrument;
 
-#[derive(Debug, Clone, Copy)]
-pub enum Attn {
-    Flashdecoding,
-    Flashinfer,
-    Paged,
-}
-
-impl TryFrom<&str> for Attn {
-    type Error = ClientError;
-    fn try_from(value: &str) -> std::result::Result<Self, Self::Error> {
-        match value {
-            "flashdecoding" => Ok(Attn::Flashdecoding),
-            "flashinfer" => Ok(Attn::Flashinfer),
-            "paged" => Ok(Attn::Paged),
-            string => Err(ClientError::InvalidAttention(string.to_string())),
-        }
-    }
-}
-
 #[derive(Debug, Clone)]
 /// Text Generation Inference gRPC multi client
 pub struct ShardedClient {
     clients: Vec<Client>,
-    attention_impl: Option<Attn>,
 }
 
 impl ShardedClient {
     fn new(clients: Vec<Client>) -> Self {
-        Self {
-            clients,
-            attention_impl: None,
-        }
+        Self { clients }
     }
 
     /// Create a new ShardedClient from a master client. The master client will communicate with
@@ -78,9 +55,7 @@ impl ShardedClient {
             .iter_mut()
             .map(|client| client.info())
             .collect();
-        let info = join_all(futures).await.pop().unwrap()?;
-        self.attention_impl = Some((&*info.attention_impl).try_into()?);
-        Ok(info)
+        join_all(futures).await.pop().unwrap()
     }
 
     /// GRPC health check
@@ -236,20 +211,14 @@ impl Health for ShardedClient {
     async fn model_health(&self) -> Result<()> {
         // Dummy batch of 1 token and 1 generated token
-        //
-        let (blocks, slots) = match self.attention_impl.expect("Attention to be set") {
-            Attn::Paged => (vec![0], (0..2).collect()),
-            Attn::Flashinfer => (vec![0, 1], (0..2).collect()),
-            Attn::Flashdecoding => (vec![0], (0..2).collect()),
-        };
 
         let liveness_request = Request {
             id: u64::MAX,
             inputs: "liveness".to_string(),
             input_chunks: Some(Input {
                 chunks: vec![Chunk::Text("liveness".into()).into()],
             }),
-            truncate: 10,
-            add_special_tokens: true,
+            truncate: 1,
+            add_special_tokens: false,
             prefill_logprobs: false,
             parameters: Some(NextTokenChooserParameters {
                 temperature: 1.0,
@@ -270,8 +239,9 @@ impl Health for ShardedClient {
                 ignore_eos_token: false,
             }),
             top_n_tokens: 0,
-            blocks,
-            slots,
+            // Block 0 is reserved for health checks
+            blocks: vec![0],
+            slots: vec![0],
             cache_len: 0,
             adapter_id: None,
             chunk_len: None,
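Taken together, these hunks show why the Attn enum could be deleted: it existed only to choose health-check blocks and slots per attention implementation (flashinfer needed blocks 0 and 1, hence the old "Block 0&1" reservation), and with every allocator reserving block 0 the probe can hard-code block 0 / slot 0. The truncate: 1 and add_special_tokens: false changes presumably keep the dummy prompt to the single token that slot can hold. A simplified before/after sketch; health_layout_old and health_layout_new are hypothetical names, not crate functions:

// Before (simplified): blocks and slots depended on the attention impl,
// which is why two blocks had to be reserved just for flashinfer.
fn health_layout_old(attn: &str) -> (Vec<u32>, Vec<u32>) {
    match attn {
        "flashinfer" => (vec![0, 1], (0..2).collect()),
        _ => (vec![0], (0..2).collect()), // paged, flashdecoding
    }
}

// After: one reserved block, one slot, independent of the attention impl.
fn health_layout_new() -> (Vec<u32>, Vec<u32>) {
    (vec![0], vec![0])
}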

View File

@@ -42,8 +42,8 @@ impl RadixAllocator {
             allocations: HashMap::new(),
             cache_blocks: RadixTrie::new(block_size as usize),
-            // XXX: Block 0,1 is reserved for health checks.
-            free_blocks: (2..n_blocks).collect(),
+            // Block 0 is reserved for health checks.
+            free_blocks: (1..n_blocks).collect(),
             window_size,
             block_size,
         }
     }

View File

@@ -33,6 +33,11 @@ pub trait Backend {
     ) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError>;
 
     async fn health(&self, current_health: bool) -> bool;
+
+    /// The health state on startup.
+    /// Typically false, or true if the backend includes
+    /// a warmup phase.
+    fn start_health(&self) -> bool;
 }
 
 /// Inference struct
@@ -75,7 +80,7 @@ impl Infer {
         let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));
 
         // Backend health
-        let backend_health = Arc::new(AtomicBool::new(false));
+        let backend_health = Arc::new(AtomicBool::new(backend.start_health()));
 
         Self {
             validation,
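The new trait method lets Infer seed its health flag from the backend instead of always starting unhealthy; BackendV2 and BackendV3 return true because their startup warmup already exercised the model. A self-contained sketch of the mechanism, assuming a much-simplified Backend trait (the real one is async, returns streams, and has no default implementation):

use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

// Hypothetical stand-in for the real `Backend` trait, reduced to the hook.
trait Backend {
    /// Health state on startup; true when a warmup phase already ran.
    fn start_health(&self) -> bool {
        false // conservative default: unhealthy until the first check
    }
}

struct WarmedUpBackend;

impl Backend for WarmedUpBackend {
    fn start_health(&self) -> bool {
        true // warmup already exercised the model end to end
    }
}

fn main() {
    let backend = WarmedUpBackend;
    // Mirrors `Infer::new` seeding its flag from `backend.start_health()`.
    let backend_health = Arc::new(AtomicBool::new(backend.start_health()));
    assert!(backend_health.load(Ordering::SeqCst));
}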