mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
Much simpler solution.
This commit is contained in:
parent
45a86d5cf0
commit
e47249a3ea
@ -104,6 +104,10 @@ impl Backend for BackendV2 {
|
|||||||
}
|
}
|
||||||
.is_ok()
|
.is_ok()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn start_health(&self) -> bool {
|
||||||
|
true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Batching logic
|
/// Batching logic
|
||||||
|
@ -111,6 +111,10 @@ impl Backend for BackendV3 {
|
|||||||
}
|
}
|
||||||
.is_ok()
|
.is_ok()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn start_health(&self) -> bool {
|
||||||
|
true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Batching logic
|
/// Batching logic
|
||||||
|
@ -147,8 +147,8 @@ impl SimpleAllocator {
|
|||||||
fn new(blocks: u32, block_size: u32, window_size: Option<u32>) -> Self {
|
fn new(blocks: u32, block_size: u32, window_size: Option<u32>) -> Self {
|
||||||
SimpleAllocator {
|
SimpleAllocator {
|
||||||
block_size,
|
block_size,
|
||||||
// XXX: Block 0&1 is reserved for health checks
|
// Block 0 is reserved for health checks
|
||||||
free_blocks: (2..blocks).collect(),
|
free_blocks: (1..blocks).collect(),
|
||||||
window_size,
|
window_size,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -37,8 +37,6 @@ pub enum ClientError {
|
|||||||
Generation(String),
|
Generation(String),
|
||||||
#[error("Sharded results are empty")]
|
#[error("Sharded results are empty")]
|
||||||
EmptyResults,
|
EmptyResults,
|
||||||
#[error("Invalid attention {0}")]
|
|
||||||
InvalidAttention(String),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<Status> for ClientError {
|
impl From<Status> for ClientError {
|
||||||
|
@ -13,38 +13,15 @@ use futures::future::join_all;
|
|||||||
use tonic::transport::Uri;
|
use tonic::transport::Uri;
|
||||||
use tracing::instrument;
|
use tracing::instrument;
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy)]
|
|
||||||
pub enum Attn {
|
|
||||||
Flashdecoding,
|
|
||||||
Flashinfer,
|
|
||||||
Paged,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl TryFrom<&str> for Attn {
|
|
||||||
type Error = ClientError;
|
|
||||||
fn try_from(value: &str) -> std::result::Result<Self, Self::Error> {
|
|
||||||
match value {
|
|
||||||
"flashdecoding" => Ok(Attn::Flashdecoding),
|
|
||||||
"flashinfer" => Ok(Attn::Flashinfer),
|
|
||||||
"paged" => Ok(Attn::Paged),
|
|
||||||
string => Err(ClientError::InvalidAttention(string.to_string())),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
/// Text Generation Inference gRPC multi client
|
/// Text Generation Inference gRPC multi client
|
||||||
pub struct ShardedClient {
|
pub struct ShardedClient {
|
||||||
clients: Vec<Client>,
|
clients: Vec<Client>,
|
||||||
attention_impl: Option<Attn>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ShardedClient {
|
impl ShardedClient {
|
||||||
fn new(clients: Vec<Client>) -> Self {
|
fn new(clients: Vec<Client>) -> Self {
|
||||||
Self {
|
Self { clients }
|
||||||
clients,
|
|
||||||
attention_impl: None,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Create a new ShardedClient from a master client. The master client will communicate with
|
/// Create a new ShardedClient from a master client. The master client will communicate with
|
||||||
@ -78,9 +55,7 @@ impl ShardedClient {
|
|||||||
.iter_mut()
|
.iter_mut()
|
||||||
.map(|client| client.info())
|
.map(|client| client.info())
|
||||||
.collect();
|
.collect();
|
||||||
let info = join_all(futures).await.pop().unwrap()?;
|
join_all(futures).await.pop().unwrap()
|
||||||
self.attention_impl = Some((&*info.attention_impl).try_into()?);
|
|
||||||
Ok(info)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// GRPC health check
|
/// GRPC health check
|
||||||
@ -236,20 +211,14 @@ impl Health for ShardedClient {
|
|||||||
|
|
||||||
async fn model_health(&self) -> Result<()> {
|
async fn model_health(&self) -> Result<()> {
|
||||||
// Dummy batch of 1 token and 1 generated token
|
// Dummy batch of 1 token and 1 generated token
|
||||||
//
|
|
||||||
let (blocks, slots) = match self.attention_impl.expect("Attention to be set") {
|
|
||||||
Attn::Paged => (vec![0], (0..2).collect()),
|
|
||||||
Attn::Flashinfer => (vec![0, 1], (0..2).collect()),
|
|
||||||
Attn::Flashdecoding => (vec![0], (0..2).collect()),
|
|
||||||
};
|
|
||||||
let liveness_request = Request {
|
let liveness_request = Request {
|
||||||
id: u64::MAX,
|
id: u64::MAX,
|
||||||
inputs: "liveness".to_string(),
|
inputs: "liveness".to_string(),
|
||||||
input_chunks: Some(Input {
|
input_chunks: Some(Input {
|
||||||
chunks: vec![Chunk::Text("liveness".into()).into()],
|
chunks: vec![Chunk::Text("liveness".into()).into()],
|
||||||
}),
|
}),
|
||||||
truncate: 10,
|
truncate: 1,
|
||||||
add_special_tokens: true,
|
add_special_tokens: false,
|
||||||
prefill_logprobs: false,
|
prefill_logprobs: false,
|
||||||
parameters: Some(NextTokenChooserParameters {
|
parameters: Some(NextTokenChooserParameters {
|
||||||
temperature: 1.0,
|
temperature: 1.0,
|
||||||
@ -270,8 +239,9 @@ impl Health for ShardedClient {
|
|||||||
ignore_eos_token: false,
|
ignore_eos_token: false,
|
||||||
}),
|
}),
|
||||||
top_n_tokens: 0,
|
top_n_tokens: 0,
|
||||||
blocks,
|
// Block 0 is reserved for health checks
|
||||||
slots,
|
blocks: vec![0],
|
||||||
|
slots: vec![0],
|
||||||
cache_len: 0,
|
cache_len: 0,
|
||||||
adapter_id: None,
|
adapter_id: None,
|
||||||
chunk_len: None,
|
chunk_len: None,
|
||||||
|
@ -42,8 +42,8 @@ impl RadixAllocator {
|
|||||||
allocations: HashMap::new(),
|
allocations: HashMap::new(),
|
||||||
cache_blocks: RadixTrie::new(block_size as usize),
|
cache_blocks: RadixTrie::new(block_size as usize),
|
||||||
|
|
||||||
// XXX: Block 0,1 is reserved for health checks.
|
// Block 0 is reserved for health checks.
|
||||||
free_blocks: (2..n_blocks).collect(),
|
free_blocks: (1..n_blocks).collect(),
|
||||||
window_size,
|
window_size,
|
||||||
block_size,
|
block_size,
|
||||||
}
|
}
|
||||||
|
@ -33,6 +33,11 @@ pub trait Backend {
|
|||||||
) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError>;
|
) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError>;
|
||||||
|
|
||||||
async fn health(&self, current_health: bool) -> bool;
|
async fn health(&self, current_health: bool) -> bool;
|
||||||
|
|
||||||
|
/// The state of the health on startup
|
||||||
|
/// Typically false, or true if the backend includes
|
||||||
|
/// a warmup phase.
|
||||||
|
fn start_health(&self) -> bool;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Inference struct
|
/// Inference struct
|
||||||
@ -75,7 +80,7 @@ impl Infer {
|
|||||||
let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));
|
let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));
|
||||||
|
|
||||||
// Backend health
|
// Backend health
|
||||||
let backend_health = Arc::new(AtomicBool::new(false));
|
let backend_health = Arc::new(AtomicBool::new(backend.start_health()));
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
validation,
|
validation,
|
||||||
|
Loading…
Reference in New Issue
Block a user