mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
Simple attempt to fix the healthcheck block allocation.
This commit is contained in:
parent
22dbc449cd
commit
45a86d5cf0
@ -147,8 +147,8 @@ impl SimpleAllocator {
|
||||
fn new(blocks: u32, block_size: u32, window_size: Option<u32>) -> Self {
|
||||
SimpleAllocator {
|
||||
block_size,
|
||||
// Block 0 is reserved for health checks
|
||||
free_blocks: (1..blocks).collect(),
|
||||
// XXX: Block 0&1 is reserved for health checks
|
||||
free_blocks: (2..blocks).collect(),
|
||||
window_size,
|
||||
}
|
||||
}
|
||||
|
@ -37,6 +37,8 @@ pub enum ClientError {
|
||||
Generation(String),
|
||||
#[error("Sharded results are empty")]
|
||||
EmptyResults,
|
||||
#[error("Invalid attention {0}")]
|
||||
InvalidAttention(String),
|
||||
}
|
||||
|
||||
impl From<Status> for ClientError {
|
||||
|
@ -13,15 +13,38 @@ use futures::future::join_all;
|
||||
use tonic::transport::Uri;
|
||||
use tracing::instrument;
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum Attn {
|
||||
Flashdecoding,
|
||||
Flashinfer,
|
||||
Paged,
|
||||
}
|
||||
|
||||
impl TryFrom<&str> for Attn {
|
||||
type Error = ClientError;
|
||||
fn try_from(value: &str) -> std::result::Result<Self, Self::Error> {
|
||||
match value {
|
||||
"flashdecoding" => Ok(Attn::Flashdecoding),
|
||||
"flashinfer" => Ok(Attn::Flashinfer),
|
||||
"paged" => Ok(Attn::Paged),
|
||||
string => Err(ClientError::InvalidAttention(string.to_string())),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
/// Text Generation Inference gRPC multi client
|
||||
pub struct ShardedClient {
|
||||
clients: Vec<Client>,
|
||||
attention_impl: Option<Attn>,
|
||||
}
|
||||
|
||||
impl ShardedClient {
|
||||
fn new(clients: Vec<Client>) -> Self {
|
||||
Self { clients }
|
||||
Self {
|
||||
clients,
|
||||
attention_impl: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new ShardedClient from a master client. The master client will communicate with
|
||||
@ -55,7 +78,9 @@ impl ShardedClient {
|
||||
.iter_mut()
|
||||
.map(|client| client.info())
|
||||
.collect();
|
||||
join_all(futures).await.pop().unwrap()
|
||||
let info = join_all(futures).await.pop().unwrap()?;
|
||||
self.attention_impl = Some((&*info.attention_impl).try_into()?);
|
||||
Ok(info)
|
||||
}
|
||||
|
||||
/// GRPC health check
|
||||
@ -211,6 +236,12 @@ impl Health for ShardedClient {
|
||||
|
||||
async fn model_health(&self) -> Result<()> {
|
||||
// Dummy batch of 1 token and 1 generated token
|
||||
//
|
||||
let (blocks, slots) = match self.attention_impl.expect("Attention to be set") {
|
||||
Attn::Paged => (vec![0], (0..2).collect()),
|
||||
Attn::Flashinfer => (vec![0, 1], (0..2).collect()),
|
||||
Attn::Flashdecoding => (vec![0], (0..2).collect()),
|
||||
};
|
||||
let liveness_request = Request {
|
||||
id: u64::MAX,
|
||||
inputs: "liveness".to_string(),
|
||||
@ -239,9 +270,8 @@ impl Health for ShardedClient {
|
||||
ignore_eos_token: false,
|
||||
}),
|
||||
top_n_tokens: 0,
|
||||
// Block 0 is reserved for health checks
|
||||
blocks: vec![0],
|
||||
slots: (0..16).collect(),
|
||||
blocks,
|
||||
slots,
|
||||
cache_len: 0,
|
||||
adapter_id: None,
|
||||
chunk_len: None,
|
||||
|
@ -42,8 +42,8 @@ impl RadixAllocator {
|
||||
allocations: HashMap::new(),
|
||||
cache_blocks: RadixTrie::new(block_size as usize),
|
||||
|
||||
// Block 0 is reserved for health checks.
|
||||
free_blocks: (1..n_blocks).collect(),
|
||||
// XXX: Block 0,1 is reserved for health checks.
|
||||
free_blocks: (2..n_blocks).collect(),
|
||||
window_size,
|
||||
block_size,
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user