text-generation-inference/backends/v3/src/lib.rs

184 lines
6.1 KiB
Rust

mod backend;
pub mod block_allocator;
mod client;
mod queue;
pub mod radix;
use crate::client::{ClientError, ShardedClient};
pub(crate) use backend::BackendV3;
use serde::Serialize;
use thiserror::Error;
use utoipa::ToSchema;
// Description of a connected V3 backend. `connect_backend` assembles it from
// its own arguments, the shard info response and the warmup answer.
//
// NOTE(review): the notes added below are plain `//` comments on purpose.
// `///` on a `ToSchema` type presumably ends up in the generated OpenAPI
// descriptions (as the two existing `///` lines likely do) — confirm before
// converting them.
#[derive(Clone, Debug, Serialize, ToSchema)]
pub struct BackendInfo {
/// Mandatory
// Copied from the shard info (`device_type`).
#[schema(example = "cuda")]
pub model_device_type: String,
// Copied from the shard info (`dtype`).
#[schema(example = "torch.float16")]
pub model_dtype: String,
/// Backend parameters
// Shard info `speculate`, cast to `usize`.
#[schema(example = "1")]
pub speculate: usize,
// Passed through unchanged from the `connect_backend` argument.
#[schema(example = "1.2")]
pub waiting_served_ratio: f32,
// Value resolved after warmup: the budget reported by the shards when they
// report one, otherwise the user-supplied value or a computed fallback.
#[schema(example = "32000")]
pub max_batch_total_tokens: u32,
// Passed through unchanged from the `connect_backend` argument.
#[schema(example = "20")]
pub max_waiting_tokens: usize,
// Passed through unchanged; `None` when the caller gave no batch size.
#[schema(nullable = true, example = "null")]
pub max_batch_size: Option<usize>,
// Copied from the shard info (`support_chunking`).
#[schema(example = "false")]
pub support_chunking: bool,
// Copied from the shard info (`use_prefix_caching`).
#[schema(example = "false")]
pub prefix_caching: bool,
// Copied from the shard info (`attention_impl`).
#[schema(example = "flashinfer")]
pub attention_impl: String,
// Copied from the shard info (`block_size`).
#[schema(example = "1")]
pub block_size: u32,
// Value returned by the shard warmup, not the raw `connect_backend` argument.
#[schema(example = "30000")]
pub max_input_tokens: usize,
// Value returned by the shard warmup, not the raw `connect_backend` argument.
#[schema(example = "32000")]
pub max_total_tokens: usize,
}
/// Connects to the model shards, warms the model up and builds the V3 backend
/// together with its [`BackendInfo`] description.
///
/// The steps run strictly in this order: connect through
/// `ShardedClient::connect_uds`, clear the shard cache, fetch the shard info,
/// warm the model up, then resolve the token budgets from the warmup answer.
///
/// When warmup reports a supported batch token budget, that value is used and
/// a user-supplied `max_batch_total_tokens` is ignored (with a warning).
/// Otherwise the user value is used, falling back to the largest of `16000`,
/// the shard's max total tokens and `max_batch_prefill_tokens`.
///
/// # Errors
///
/// * [`V3Error::Connection`] if connecting to `master_shard_uds_path` fails.
/// * [`V3Error::Cache`] if clearing the shard cache fails.
/// * [`V3Error::Info`] if the shard info request fails.
/// * [`V3Error::Warmup`] if the warmup request fails.
/// * [`V3Error::NotEnoughMemory`] if the supported batch budget reported by
///   warmup is smaller than the shard's max total tokens.
///
/// # Panics
///
/// Panics if `max_input_tokens` or `max_total_tokens` is provided and, once
/// cast with `as u32`, differs from the value returned by warmup.
// One parameter per launcher setting; kept flat so existing callers are unaffected.
#[allow(clippy::too_many_arguments)]
pub async fn connect_backend(
    max_input_tokens: Option<usize>,
    max_total_tokens: Option<usize>,
    master_shard_uds_path: String,
    waiting_served_ratio: f32,
    max_batch_prefill_tokens: u32,
    max_batch_total_tokens: Option<u32>,
    max_waiting_tokens: usize,
    max_batch_size: Option<usize>,
) -> Result<(BackendV3, BackendInfo), V3Error> {
    let mut client = ShardedClient::connect_uds(master_shard_uds_path)
        .await
        .map_err(V3Error::Connection)?;

    // The server is running on v3.
    // Drop whatever the shards still cache, e.g. after a webserver reboot.
    client.clear_cache(None).await.map_err(V3Error::Cache)?;

    let shard = client.info().await.map_err(V3Error::Info)?;

    tracing::info!("Warming up model");
    let requested_input_tokens = max_input_tokens.map(|tokens| tokens as u32);
    let requested_total_tokens = max_total_tokens.map(|tokens| tokens as u32);
    let warmup_answer: (Option<u32>, u32, u32) = client
        .warmup(
            requested_input_tokens,
            max_batch_prefill_tokens,
            requested_total_tokens,
            max_batch_size,
        )
        .await
        .map_err(V3Error::Warmup)?;
    let (supported_batch_total_tokens, shard_max_input_tokens, shard_max_total_tokens) =
        warmup_answer;

    // Explicitly requested limits must come back unchanged from the shards.
    if let Some(requested) = requested_input_tokens {
        assert_eq!(requested, shard_max_input_tokens);
    }
    if let Some(requested) = requested_total_tokens {
        assert_eq!(requested, shard_max_total_tokens);
    }

    let max_batch_total_tokens = match supported_batch_total_tokens {
        // Flash attention models return their max supported total tokens.
        Some(max_supported_batch_total_tokens) => {
            // A user-provided budget is ignored on this path, so say so.
            if max_batch_total_tokens.is_some() {
                tracing::warn!(
                    "`--max-batch-total-tokens` is deprecated for Flash \
                     Attention models."
                );
                tracing::warn!(
                    "Inferred max batch total tokens: {max_supported_batch_total_tokens}"
                );
            }
            if max_supported_batch_total_tokens < shard_max_total_tokens {
                return Err(V3Error::NotEnoughMemory(shard_max_total_tokens as usize));
            }
            max_supported_batch_total_tokens
        }
        // Older models do not support automatic max-batch-total-tokens.
        None => {
            let fallback_budget = 16000_u32
                .max(shard_max_total_tokens)
                .max(max_batch_prefill_tokens);
            let budget = max_batch_total_tokens.unwrap_or(fallback_budget);
            tracing::warn!("Model does not support automatic max batch total tokens");
            budget
        }
    };

    // From here on the limits are the ones the shards reported.
    let max_input_tokens = shard_max_input_tokens as usize;
    let max_total_tokens = shard_max_total_tokens as usize;

    tracing::info!("Setting max batch total tokens to {max_batch_total_tokens}");
    metrics::gauge!("tgi_batch_max_total_tokens").set(max_batch_total_tokens);

    // Strings are cloned because `shard` itself is handed to the backend below.
    let backend_info = BackendInfo {
        model_device_type: shard.device_type.clone(),
        model_dtype: shard.dtype.clone(),
        speculate: shard.speculate as usize,
        waiting_served_ratio,
        max_batch_total_tokens,
        max_waiting_tokens,
        max_batch_size,
        support_chunking: shard.support_chunking,
        prefix_caching: shard.use_prefix_caching,
        attention_impl: shard.attention_impl.clone(),
        block_size: shard.block_size,
        max_input_tokens,
        max_total_tokens,
    };

    let backend = BackendV3::new(
        client,
        waiting_served_ratio,
        max_batch_prefill_tokens,
        max_batch_total_tokens,
        max_waiting_tokens,
        max_batch_size,
        shard,
    );
    tracing::info!("Using backend V3");

    Ok((backend, backend_info))
}
/// Errors that `connect_backend` can return while setting up the V3 backend.
///
/// Each `ClientError`-carrying variant wraps the failure of one request to the
/// shards; the `#[error]` message names the step that failed.
#[derive(Debug, Error)]
pub enum V3Error {
/// The request to clear the shard cache failed.
#[error("Unable to clear the Python model shards cache: {0}")]
Cache(ClientError),
/// Connecting to the shards failed.
#[error("Unable to connect to the Python model shards: {0}")]
Connection(ClientError),
/// The shard info request failed.
#[error("Unable to get the Python model shards info: {0}")]
Info(ClientError),
/// The warmup request failed.
#[error("Unable to warmup the Python model shards: {0}")]
Warmup(ClientError),
/// Warmup reported a supported batch token budget smaller than the shard's
/// max total tokens; the payload is that max total tokens value.
#[error("Not enough memory to handle `max_total_tokens={0}`")]
NotEnoughMemory(usize),
}