Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-12 04:44:52 +00:00)

Commit 7f9abde3f8 — "load tested"
Parent: 34f5dc525e
@@ -159,6 +159,7 @@ impl Client {
                 blocks: vec![],
                 slots: vec![],
                 prefix_len: 0,
+                postfix_len: truncate,
                 // Set sampling parameters to also take these ops into account in the max memory
                 parameters: Some(NextTokenChooserParameters {
                     temperature: 0.9,
@@ -246,6 +246,7 @@ impl Health for ShardedClient {
            blocks: vec![0],
            slots: (0..16).collect(),
            prefix_len: 0,
+           postfix_len: 1,
            adapter_id: None,
        };
        let batch = Batch {
@@ -34,9 +34,13 @@ impl BackendV3 {
        requires_padding: bool,
        window_size: Option<u32>,
        speculate: u32,
+       support_chunking: bool,
    ) -> Self {
-       let prefix_caching =
-           std::env::var("USE_PREFIX_CACHING").unwrap_or("1".to_string());
+       if support_chunking {
+           tracing::warn!("Model supports prefill chunking. `waiting_served_ratio` and `max_waiting_tokens` will be ignored.");
+       }
+
+       let prefix_caching = std::env::var("USE_PREFIX_CACHING").unwrap_or("1".to_string());
        let prefix_caching = matches!(prefix_caching.as_str(), "true" | "1");
        let attention: String = std::env::var("ATTENTION").unwrap_or("flashinfer".to_string());
 
@@ -52,6 +56,7 @@ impl BackendV3 {
            window_size,
            speculate,
            max_batch_total_tokens,
+           support_chunking,
        );
        let batching_task_notifier = Arc::new(Notify::new());
 
@@ -63,6 +68,7 @@ impl BackendV3 {
            max_batch_total_tokens,
            max_waiting_tokens,
            max_batch_size,
+           support_chunking,
            queue.clone(),
            batching_task_notifier.clone(),
        ));
@@ -127,6 +133,7 @@ pub(crate) async fn batching_task(
    max_batch_total_tokens: u32,
    max_waiting_tokens: usize,
    max_batch_size: Option<usize>,
+   support_chunking: bool,
    queue: Queue,
    notifier: Arc<Notify>,
) {
@@ -158,28 +165,44 @@ pub(crate) async fn batching_task(
            // Get current batch info
            let batch_size = batch.size;
            let batch_max_tokens = batch.max_tokens;
+           let current_tokens = batch.current_tokens;
            let mut batches = vec![batch];
            metrics::gauge!("tgi_batch_current_size").set(batch_size as f64);
            metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64);
 
-           let min_size = if waiting_tokens >= max_waiting_tokens {
-               // If we didn't onboard any new requests since >= max_waiting_tokens, we try
-               // to add a new batch even though its size might be small
-               None
-           } else {
-               // Minimum batch size
-               // TODO: temporarily disable to avoid incorrect deallocation +
-               // reallocation when using prefix caching.
-               Some((batch_size as f32 * waiting_served_ratio).floor() as usize)
-           };
 
            let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
-           let max_size =
-               max_batch_size.map(|max_size| max_size.saturating_sub(batch_size as usize));
+           let (min_size, max_size, prefill_token_budget) = if support_chunking {
+               // Since the next batch will be concatenated with the current batch,
+               // the current batch tokens must be subtracted to the prefill budget
+               // In the future, we could concatenate beforehand
+               let prefill_token_budget = max_batch_prefill_tokens - current_tokens;
+               // We can ignore min_size and max_size
+               // Models than rely on max_size cannot support chunking
+               // Regarding min_size, chunking allow us to consistently run at the compute
+               // bound, making min_size useless.
+               (None, None, prefill_token_budget)
+           } else {
+               let min_size = if waiting_tokens >= max_waiting_tokens {
+                   // If we didn't onboard any new requests since >= max_waiting_tokens, we try
+                   // to add a new batch even though its size might be small
+                   None
+               } else {
+                   // Minimum batch size
+                   // TODO: temporarily disable to avoid incorrect deallocation +
+                   // reallocation when using prefix caching.
+                   Some((batch_size as f32 * waiting_served_ratio).floor() as usize)
+               };
+
+               let max_size =
+                   max_batch_size.map(|max_size| max_size.saturating_sub(batch_size as usize));
+
+               (min_size, max_size, max_batch_prefill_tokens)
+           };
 
            // Try to get a new batch
            if let Some((mut new_entries, new_batch, span)) = queue
-               .next_batch(min_size, max_size, max_batch_prefill_tokens, token_budget)
+               .next_batch(min_size, max_size, prefill_token_budget, token_budget)
                .await
            {
                // Tracking metrics
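To make the budget arithmetic of the new `support_chunking` branch concrete, here is a small illustrative sketch (not part of the commit, made-up numbers); it only mirrors the subtraction of the running batch's `current_tokens` from `max_batch_prefill_tokens` shown above.

# Illustrative only: budget arithmetic of the support_chunking branch above.
max_batch_prefill_tokens = 4096
current_tokens = 1536            # tokens already in the running batch (hypothetical)

# With chunking, the next batch is concatenated with the current one, so its
# prefill budget is whatever is left of the prefill window.
min_size = None
max_size = None
prefill_token_budget = max_batch_prefill_tokens - current_tokens   # 2560

print(min_size, max_size, prefill_token_budget)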
@@ -159,6 +159,7 @@ impl Client {
                blocks: vec![],
                slots: vec![],
                prefix_len: 0,
+               postfix_len: truncate,
                // Set sampling parameters to also take these ops into account in the max memory
                parameters: Some(NextTokenChooserParameters {
                    temperature: 0.9,
@@ -29,15 +29,6 @@ pub trait Health {
    async fn model_health(&self) -> Result<()>;
 }
 
-#[derive(Debug)]
-pub struct ShardInfo {
-    pub requires_padding: bool,
-    pub dtype: String,
-    pub device_type: String,
-    pub window_size: Option<u32>,
-    pub speculate: u32,
-}
-
 #[derive(Error, Debug, Clone)]
 pub enum ClientError {
    #[error("Could not connect to Text Generation server: {0}")]
@@ -1,6 +1,6 @@
-use crate::client::{ClientError, Result};
+use crate::client::Health;
 /// Multi shard Client
-use crate::client::{Health, ShardInfo};
+use crate::client::{ClientError, Result};
 
 use crate::client::grpc_client::{DecodeTimings, PrefillTimings};
 use crate::client::{
@@ -49,13 +49,13 @@ impl ShardedClient {
 
    /// Get the model info
    #[instrument(skip(self))]
-   pub async fn info(&mut self) -> Result<ShardInfo> {
+   pub async fn info(&mut self) -> Result<InfoResponse> {
        let futures: Vec<_> = self
            .clients
            .iter_mut()
            .map(|client| client.info())
            .collect();
-       join_all(futures).await.pop().unwrap().map(ShardInfo::from)
+       join_all(futures).await.pop().unwrap()
    }
 
    /// GRPC health check
@@ -194,18 +194,6 @@ impl ShardedClient {
        }
    }
 
-impl From<InfoResponse> for ShardInfo {
-    fn from(value: InfoResponse) -> Self {
-        Self {
-            requires_padding: value.requires_padding,
-            dtype: value.dtype,
-            device_type: value.device_type,
-            window_size: value.window_size,
-            speculate: value.speculate,
-        }
-    }
-}
-
 #[async_trait]
 impl Health for ShardedClient {
    async fn device_health(&self) -> Result<()> {
@@ -248,6 +236,7 @@ impl Health for ShardedClient {
            slots: (0..16).collect(),
            prefix_len: 0,
            adapter_id: None,
+           postfix_len: 1,
        };
        let batch = Batch {
            id: u64::MAX,
@@ -29,6 +29,8 @@ pub struct BackendInfo {
    pub max_waiting_tokens: usize,
    #[schema(nullable = true, example = "null")]
    pub max_batch_size: Option<usize>,
+   #[schema(example = "false")]
+   pub support_chunking: bool,
 }
 
 #[allow(clippy::too_many_arguments)]
@@ -110,6 +112,7 @@ pub async fn connect_backend(
        model_device_type: shard_info.device_type.clone(),
        model_dtype: shard_info.dtype.clone(),
        speculate: shard_info.speculate as usize,
+       support_chunking: shard_info.support_chunking,
    };
 
    let backend = BackendV3::new(
@@ -122,6 +125,7 @@ pub async fn connect_backend(
        shard_info.requires_padding,
        shard_info.window_size,
        shard_info.speculate,
+       shard_info.support_chunking,
    );
 
    tracing::info!("Using backend V3");
@@ -131,25 +131,12 @@ async fn main() -> Result<(), RouterError> {
            "`max_input_tokens` must be < `max_total_tokens`".to_string(),
        ));
    }
-   if max_input_tokens as u32 > max_batch_prefill_tokens {
-       return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}")));
-   }
-
    if validation_workers == 0 {
        return Err(RouterError::ArgumentValidation(
            "`validation_workers` must be > 0".to_string(),
        ));
    }
 
-   if let Some(ref max_batch_total_tokens) = max_batch_total_tokens {
-       if max_batch_prefill_tokens > *max_batch_total_tokens {
-           return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")));
-       }
-       if max_total_tokens as u32 > *max_batch_total_tokens {
-           return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")));
-       }
-   }
-
    if let Some(max_batch_size) = max_batch_size {
        if max_batch_size == 0 {
            return Err(RouterError::ArgumentValidation(
@@ -158,7 +145,7 @@ async fn main() -> Result<(), RouterError> {
        }
    }
 
-   let (backend, _backend_info) = connect_backend(
+   let (backend, backend_info) = connect_backend(
        max_input_tokens,
        max_total_tokens,
        master_shard_uds_path,
@@ -170,6 +157,19 @@ async fn main() -> Result<(), RouterError> {
    )
    .await?;
 
+   // Validate remaining args now that the backend is known
+   let support_chunking = backend_info.support_chunking;
+   let max_batch_total_tokens = backend_info.max_batch_total_tokens;
+   if max_input_tokens as u32 > max_batch_prefill_tokens && !support_chunking {
+       return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}")));
+   }
+   if max_batch_prefill_tokens > max_batch_total_tokens {
+       return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")));
+   }
+   if max_total_tokens as u32 > max_batch_total_tokens {
+       return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")));
+   }
+
    // Run server
    server::run(
        backend,
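A compact restatement (not from the commit, hypothetical values) of the validation order after this change: the prefill/total-token checks now run only once `backend_info` is available, and the `max_input_tokens` check is skipped when the backend supports chunking.

# Illustrative re-statement of the deferred checks above (hypothetical values).
support_chunking = True            # would come from backend_info.support_chunking
max_batch_total_tokens = 16384     # would come from backend_info.max_batch_total_tokens
max_input_tokens, max_total_tokens = 8192, 8704
max_batch_prefill_tokens = 4096

if max_input_tokens > max_batch_prefill_tokens and not support_chunking:
    raise ValueError("`max_batch_prefill_tokens` must be >= `max_input_tokens`")
if max_batch_prefill_tokens > max_batch_total_tokens:
    raise ValueError("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`")
if max_total_tokens > max_batch_total_tokens:
    raise ValueError("`max_total_tokens` must be <= `max_batch_total_tokens`")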
@@ -4,7 +4,7 @@ use crate::client::{
    Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
 };
 use nohash_hasher::{BuildNoHashHasher, IntMap};
-use std::cmp::{max, min};
+use std::cmp::max;
 use std::collections::VecDeque;
 use text_generation_router::infer::InferError;
 use text_generation_router::infer::InferStreamResponse;
@@ -50,6 +50,7 @@ impl Queue {
        window_size: Option<u32>,
        speculate: u32,
        max_batch_total_tokens: u32,
+       support_chunking: bool,
    ) -> Self {
        // Create channel
        let (queue_sender, queue_receiver) = mpsc::unbounded_channel();
@@ -62,6 +63,7 @@ impl Queue {
            window_size,
            speculate,
            max_batch_total_tokens,
+           support_chunking,
            queue_receiver,
        ));
 
@@ -108,6 +110,7 @@ impl Queue {
 }
 
 // Background task responsible of the queue state
+#[allow(clippy::too_many_arguments)]
 async fn queue_task(
    requires_padding: bool,
    block_size: u32,
@@ -115,6 +118,7 @@ async fn queue_task(
    window_size: Option<u32>,
    speculate: u32,
    max_batch_total_tokens: u32,
+   support_chunking: bool,
    mut receiver: mpsc::UnboundedReceiver<QueueCommand>,
 ) {
    let mut state = State::new(
@@ -124,6 +128,7 @@ async fn queue_task(
        window_size,
        speculate,
        max_batch_total_tokens,
+       support_chunking,
    );
 
    while let Some(cmd) = receiver.recv().await {
@@ -166,12 +171,14 @@ struct State {
    /// Paged Attention block size
    block_size: u32,
 
-   /// Sliding window
-   window_size: Option<u32>,
-
    /// Speculation amount
    speculate: u32,
 
+   /// Whether the model allow the prefill chunking
+   /// If it does, the last request in the batch will be split to exactly match the prefill
+   /// token budget
+   support_chunking: bool,
+
    /// Paged Attention Block Allocation
    block_allocator: Option<BlockAllocator>,
 }
@@ -184,6 +191,7 @@ impl State {
        window_size: Option<u32>,
        speculate: u32,
        max_batch_total_tokens: u32,
+       support_chunking: bool,
    ) -> Self {
        let block_allocator = (!requires_padding).then(|| {
            BlockAllocator::new(
@@ -199,8 +207,8 @@ impl State {
            next_id: 0,
            next_batch_id: 0,
            block_size,
-           window_size,
            speculate,
+           support_chunking,
            block_allocator,
        }
    }
@@ -268,7 +276,7 @@ impl State {
                continue;
            }
 
-           let block_allocation = match &self.block_allocator {
+           let (block_allocation, postfix_len) = match &self.block_allocator {
                None => {
                    // We pad to max input length in the Python shards
                    // We need to take these padding tokens into the equation
@@ -285,34 +293,9 @@ impl State {
                        self.entries.push_front((id, entry));
                        break 'entry_loop;
                    }
-                   None
+                   (None, entry.request.input_length)
                }
-               Some(_block_allocator) => {
-                   prefill_tokens += entry.request.input_length;
-                   let max_new_tokens = match self.window_size {
-                       None => entry.request.stopping_parameters.max_new_tokens,
-                       Some(window_size) => min(
-                           window_size.saturating_sub(entry.request.input_length),
-                           entry.request.stopping_parameters.max_new_tokens,
-                       ),
-                   };
-                   decode_tokens += max_new_tokens;
-
-                   if prefill_tokens > prefill_token_budget
-                       || (prefill_tokens + decode_tokens + self.speculate) > token_budget
-                   {
-                       // Entry is over budget
-                       // Add it back to the front
-                       tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
-                       self.entries.push_front((id, entry));
-                       break;
-                   }
-
-                   let tokens = entry.request.input_length
-                       + entry.request.stopping_parameters.max_new_tokens
-                       + self.speculate
-                       - 1;
-
+               Some(block_allocator) => {
                    // If users wants the prefill logprobs, we cannot reuse the cache.
                    // So no input_ids for the radix tree.
                    let input_ids = if entry.request.decoder_input_details {
@@ -321,10 +304,65 @@ impl State {
                        entry.request.input_ids.clone()
                    };
 
-                   Some((tokens, input_ids))
+                   let tokens = entry.request.input_length
+                       + entry.request.stopping_parameters.max_new_tokens
+                       + self.speculate
+                       - 1;
+                   tracing::debug!("Allocating {tokens} with {input_ids:?}");
+
+                   let block_allocation = match block_allocator.allocate(tokens, input_ids).await {
+                       None => {
+                           // Entry is over budget
+                           // Add it back to the front
+                           tracing::debug!("Over budget: not enough free blocks");
+                           self.entries.push_front((id, entry));
+                           break 'entry_loop;
+                       }
+                       Some(mut block_allocation) => {
+                           tracing::debug!("Allocation: {block_allocation:?}");
+                           max_blocks = max(max_blocks, block_allocation.blocks.len() as u32);
+
+                           if block_allocation.prefix_len == entry.request.input_length {
+                               // The whole request was found in the radix trie
+                               // However, for the transformer forward to work, we need to
+                               // have at least one token of postfix.
+                               block_allocation.prefix_len -= 1;
+                           }
+
+                           block_allocation
+                       }
+                   };
+
+                   let mut postfix_len = entry.request.input_length - block_allocation.prefix_len;
+
+                   // Check equality too as if we don't we might end up with a postfix_len = 0
+                   // in the next iteration of the loop
+                   if prefill_tokens + postfix_len >= prefill_token_budget {
+                       // Entry is over budget
+                       if self.support_chunking {
+                           // We support chunking, just set postfix_len to exactly match prefill_token_budget
+                           postfix_len = prefill_token_budget - prefill_tokens;
+                           // Push this entry inside the batch
+                           batch.push((id, entry, Some(block_allocation), postfix_len));
+                           break 'entry_loop;
+                       } else {
+                           // We don't support chunking, this entry needs to go back to the buffer
+                           // Add it back to the front
+                           tracing::debug!(
+                               "Over budget: prefill_tokens={} > {prefill_token_budget}",
+                               prefill_tokens + postfix_len
+                           );
+                           self.entries.push_front((id, entry));
+                           break 'entry_loop;
+                       }
+                   }
+
+                   prefill_tokens += postfix_len;
+
+                   (Some(block_allocation), postfix_len)
                }
            };
-           batch.push((id, entry, block_allocation));
+           batch.push((id, entry, block_allocation, postfix_len));
            if Some(batch.len()) == max_size {
                break;
            }
@@ -342,7 +380,7 @@ impl State {
        // Batch is too small
        if batch.len() < min_size {
            // Add back entries to the queue in the correct order
-           for (id, entry, _) in batch.into_iter().rev() {
+           for (id, entry, _, _) in batch.into_iter().rev() {
                self.entries.push_front((id, entry));
            }
            return None;
@@ -353,29 +391,7 @@ impl State {
        let mut batch_entries =
            IntMap::with_capacity_and_hasher(self.entries.len(), BuildNoHashHasher::default());
 
-       for (id, mut entry, block_allocation) in batch {
-           let block_allocation = if let (Some((tokens, input_ids)), Some(block_allocator)) =
-               (block_allocation, &self.block_allocator)
-           {
-               tracing::debug!("Allocating {tokens} with {input_ids:?}");
-               match block_allocator.allocate(tokens, input_ids).await {
-                   None => {
-                       // Entry is over budget
-                       // Add it back to the front
-                       tracing::debug!("Over budget: not enough free blocks");
-                       self.entries.push_front((id, entry));
-                       continue;
-                   }
-                   Some(block_allocation) => {
-                       tracing::debug!("Allocation: {block_allocation:?}");
-                       max_blocks = max(max_blocks, block_allocation.blocks.len() as u32);
-                       Some(block_allocation)
-                   }
-               }
-           } else {
-               None
-           };
-           tracing::debug!("Accepting entry");
+       for (id, mut entry, block_allocation, postfix_len) in batch {
            // Create a new span to link the batch back to this entry
            let entry_batch_span = info_span!(parent: &entry.span, "infer");
            // Add relationships
@@ -429,6 +445,7 @@ impl State {
                slots,
                prefix_len,
                adapter_id: entry.request.adapter_id.clone(),
+               postfix_len,
            });
            // Set batch_time
            entry.batch_time = Some(Instant::now());
@@ -436,12 +453,6 @@ impl State {
            batch_entries.insert(id, entry);
        }
 
-       // Empty batch
-       if batch_requests.is_empty() {
-           tracing::debug!("Filterered out all entries");
-           return None;
-       }
-
        // Final batch size
        let size = batch_requests.len() as u32;
        next_batch_span.record("batch_size", size);
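A small stand-alone sketch (not from the commit, made-up numbers) of the chunk-splitting decision above: when the pending postfix would overshoot the prefill budget, a chunking-capable model gets a clamped `postfix_len`, otherwise the request goes back to the queue.

# Illustrative only: how the queue decides whether to chunk the last request.
prefill_tokens = 1800          # tokens already accepted for this prefill
prefill_token_budget = 2048
input_length = 900
prefix_len = 300               # tokens reused from the prefix cache
support_chunking = True

postfix_len = input_length - prefix_len          # 600 tokens still to prefill
if prefill_tokens + postfix_len >= prefill_token_budget:
    if support_chunking:
        # Split the request so the batch lands exactly on the budget.
        postfix_len = prefill_token_budget - prefill_tokens   # 248
    else:
        # Without chunking the request has to wait for the next batch.
        postfix_len = None

print(postfix_len)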
@@ -159,6 +159,7 @@ async fn prefill(
            blocks: vec![],
            slots: vec![],
            prefix_len: 0,
+           postfix_len: sequence_length,
            adapter_id: None,
        })
        .collect();
@@ -34,6 +34,7 @@ message InfoResponse {
   string device_type = 3;
   optional uint32 window_size = 4;
   uint32 speculate = 5;
+  bool support_chunking = 6;
 }
 
 /// Empty request
@@ -139,6 +140,8 @@ message Request {
   uint32 prefix_len = 12;
   /// Context truncation
   bool add_special_tokens = 13;
+  /// Postfix length for prefill chunking
+  uint32 postfix_len = 14;
 }
 
 message Batch {
@@ -163,6 +166,8 @@ message CachedBatch {
   uint32 size = 3;
   /// Maximum number of tokens this batch will grow to
   uint32 max_tokens = 4;
+  /// Number of tokens in the next forward
+  uint32 current_tokens = 5;
 }
 
 enum FinishReason {
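As a quick illustration of the new wire fields (hypothetical values, only the chunking-related fields shown), a construction with the generated `generate_pb2` module might look like the following sketch.

from text_generation_server.pb import generate_pb2

# Hypothetical values; fields unrelated to chunking are omitted for brevity.
request = generate_pb2.Request(
    id=0,
    prefix_len=128,      # tokens reused from the prefix cache
    postfix_len=512,     # tokens to run through prefill in this forward
)
cached_batch = generate_pb2.CachedBatch(
    id=0,
    size=1,
    max_tokens=2048,
    current_tokens=512,  # tokens in the next forward
)
print(request.postfix_len, cached_batch.current_tokens)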
@@ -1,7 +1,7 @@
 import pytest
-import os
 from text_generation_server.pb import generate_pb2
 
+
 @pytest.fixture
 def default_pb_parameters():
     return generate_pb2.NextTokenChooserParameters(
@@ -76,6 +76,7 @@ class CausalLMBatch(Batch):
             request_ids=[r.id for r in self.requests],
             size=len(self),
             max_tokens=self.max_tokens,
+            current_tokens=len(self),
         )
 
     @classmethod
@@ -16,7 +16,17 @@ from transformers import (
     AutoTokenizer,
     GenerationConfig,
 )
-from typing import Any, ContextManager, Iterable, Optional, Tuple, List, Type, Dict, Union
+from typing import (
+    Any,
+    ContextManager,
+    Iterable,
+    Optional,
+    Tuple,
+    List,
+    Type,
+    Dict,
+    Union,
+)
 
 from text_generation_server.adapters import AdapterBatchData, AdapterBatchMetadata
 from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
@@ -24,6 +34,10 @@ from text_generation_server.utils.chunks import concat_text_chunks
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.models import Model
 from text_generation_server.utils.log import log_master
+from text_generation_server.utils.prefill_chunking import (
+    get_support_chunking,
+    get_max_prefill_tokens,
+)
 from text_generation_server.utils.tokens import batch_top_tokens
 from text_generation_server.utils.speculate import get_speculate
 from text_generation_server.utils import (
@@ -60,12 +74,9 @@ from text_generation_server.utils.import_utils import (
 
 tracer = trace.get_tracer(__name__)
 
 
 # Will be set in init
 SLIDING_WINDOW: Optional[int] = None
 
-TOKEN_BUDGET = 8
-
-
 def set_sliding_window(sliding_window: int):
     global SLIDING_WINDOW
@@ -206,6 +217,11 @@ class FlashCausalLMBatch(Batch):
             request_ids=[r.id for r in self.requests],
             size=len(self),
             max_tokens=self.num_blocks * BLOCK_SIZE,
+            current_tokens=(
+                sum([len(i) for i in self.input_ids])
+                if isinstance(self.input_ids, list)
+                else len(self.input_ids)
+            ),
         )
 
     @classmethod
@@ -272,7 +288,7 @@ class FlashCausalLMBatch(Batch):
             prompt_lengths.append(prompt_length)
 
             prefix_length = r.prefix_len
-            postfix_length = prefix_length + 10
+            postfix_length = r.postfix_len
             assert (
                 prefix_length <= prompt_length
             ), f"Prefix {prefix_length} vs input {prompt_length}"
@@ -282,10 +298,13 @@ class FlashCausalLMBatch(Batch):
             if prefix_length + postfix_length < prompt_length:
                 # FIXME: speculate is not supported for context chunking at the moment
                 assert speculate == 0
+                assert get_support_chunking()
+                assert postfix_length > 0
 
             prefix_ids.append(tokenized_input[:prefix_length])
-            postfix_ids = tokenized_input[prefix_length : postfix_length]
-            # postfix_ids = tokenized_input[prefix_length:]
+            postfix_ids = tokenized_input[
+                prefix_length : prefix_length + postfix_length
+            ]
 
             postfix_length = len(postfix_ids)
             postfix_lengths.append(postfix_length)
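A tiny stand-alone illustration (not from the commit) of the slicing above, with made-up lengths, showing which tokens a chunked prefill actually feeds to the model.

# Illustrative only: chunked prompt slicing as in the hunk above.
tokenized_input = list(range(100))   # a 100-token prompt (hypothetical)
prefix_length = 40                   # already present in the prefix cache
postfix_length = 32                  # chunk scheduled for this prefill pass

prefix_ids = tokenized_input[:prefix_length]
postfix_ids = tokenized_input[prefix_length : prefix_length + postfix_length]

assert len(postfix_ids) == postfix_length
# Tokens 72..99 remain pending and will be prefilled in a later chunk.
print(len(prefix_ids), len(postfix_ids))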
@@ -371,7 +390,6 @@ class FlashCausalLMBatch(Batch):
            requests=pb.requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=all_postfix_ids,
-
            block_tables=block_tables,
            block_tables_tensor=block_tables_tensor,
            prefix_lengths=prefix_lengths,
@@ -395,7 +413,6 @@ class FlashCausalLMBatch(Batch):
            max_blocks=max_blocks,
            speculative_ids=None,
            prompt_lengths_tensor=prompt_lengths_tensor,
-
            # These values will be set by `FlashCausalLMBatch.prepare_for_prefill`
            position_ids=None,
            cu_seqlen_prefill=None,
@@ -431,7 +448,7 @@ class FlashCausalLMBatch(Batch):
        if len(request_ids) == len(self):
            return self
 
-       device = self.input_ids.device
+       device = self.block_tables_tensor.device
 
        # New values after filtering
        requests_idx_mapping = {}
@@ -552,13 +569,13 @@ class FlashCausalLMBatch(Batch):
 
        if self.prefilling:
            # These values will be set by `FlashCausalLMBatch.prepare_for_prefill`
-           position_ids=None
-           start_slots=None
-           slot_indices=None
-           slots=None
-           prefix_lengths_tensor=None
-           postfix_lengths_tensor=None
-           adapter_meta=None
+           position_ids = None
+           start_slots = None
+           slot_indices = None
+           slots = None
+           prefix_lengths_tensor = None
+           postfix_lengths_tensor = None
+           adapter_meta = None
        else:
            # Index into tensors
            input_ids = self.input_ids[indices]
@@ -643,24 +660,24 @@ class FlashCausalLMBatch(Batch):
        max_current_length = 0
        for b in batches:
            total_batch_size += len(b)
-           total_slots += len(b.slots)
+           max_blocks = max(max_blocks, b.max_blocks)
+           # If `b` is prefilling and was just filtered, `b.slots` is None
+           # `total_slots` is not used if any of the batches is prefilling
+           total_slots += len(b.slots) if not b.prefilling else 0
            num_blocks += b.num_blocks
            speculative_length = (
                b.speculative_ids.shape[1] if b.speculative_ids is not None else 0
            )
-           max_blocks = max(max_blocks, b.max_blocks)
            max_postfix_length = max(max_postfix_length, b.max_postfix_length)
            max_current_length = max(max_current_length, b.max_current_length)
            max_length = max(
                max_length,
                max(
-                   prefix_length
-                   + postfix_length
+                   prompt_length
                    + stopping_criteria.max_new_tokens
                    + speculative_length
-                   - stopping_criteria.current_tokens
-                   for prefix_length, postfix_length, stopping_criteria in zip(
-                       b.prefix_lengths, b.postfix_lengths, b.stopping_criterias
+                   for prompt_length, stopping_criteria in zip(
+                       b.prompt_lengths, b.stopping_criterias
                    )
                ),
            )
@@ -669,14 +686,14 @@ class FlashCausalLMBatch(Batch):
        if prefilling:
            input_ids = []
            # These values will be set by `FlashCausalLMBatch.prepare_for_prefill`
-           position_ids=None
-           start_slots=None
-           slots=None
-           slot_indices=None
-           prefix_lengths_tensor=None
-           postfix_lengths_tensor=None
-           adapter_meta=None
-           adapter_segment_builder=None
+           position_ids = None
+           start_slots = None
+           slots = None
+           slot_indices = None
+           prefix_lengths_tensor = None
+           postfix_lengths_tensor = None
+           adapter_meta = None
+           adapter_segment_builder = None
        else:
            input_ids = batches[0].input_ids.new_empty(total_batch_size)
            position_ids = batches[0].position_ids.new_empty(total_batch_size)
@@ -746,8 +763,6 @@ class FlashCausalLMBatch(Batch):
 
            start_index = cumulative_batch_size
            end_index = cumulative_batch_size + len(batch)
-           slots_start_index = cumulative_slots
-           slots_end_index = cumulative_slots + len(batch.slots)
 
            # Copy tensors (GPU)
            top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
@@ -761,10 +776,17 @@ class FlashCausalLMBatch(Batch):
            prompt_lengths_tensor[start_index:end_index] = batch.prompt_lengths_tensor
 
            if not prefilling:
+               slots_start_index = cumulative_slots
+               slots_end_index = cumulative_slots + len(batch.slots)
+
                input_ids[start_index:end_index] = batch.input_ids
                position_ids[start_index:end_index] = batch.position_ids
-               slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots
-               postfix_lengths_tensor[start_index:end_index] = batch.postfix_lengths_tensor
+               slot_indices[start_index:end_index] = (
+                   batch.slot_indices + cumulative_slots
+               )
+               postfix_lengths_tensor[start_index:end_index] = (
+                   batch.postfix_lengths_tensor
+               )
                slots[slots_start_index:slots_end_index] = batch.slots
 
                # Copy over adapter indices
@@ -779,11 +801,17 @@ class FlashCausalLMBatch(Batch):
                cumulative_adapter_indices_size = adapter_end_index
                adapter_set.update(batch.adapter_meta.adapter_set)
                adapter_segment_builder.concat(
-                   batch.adapter_meta.adapter_segments, batch.adapter_meta.segment_indices
+                   batch.adapter_meta.adapter_segments,
+                   batch.adapter_meta.segment_indices,
+               )
+               prefix_lengths_tensor[start_index:end_index] = (
+                   batch.prefix_lengths_tensor
                )
-               prefix_lengths_tensor[start_index:end_index] = batch.prefix_lengths_tensor
 
                start_slots.append(batch.start_slots + cumulative_slots)
+
+               # Update
+               cumulative_slots += len(batch.slots)
            else:
                if isinstance(batch.input_ids, torch.Tensor):
                    batch.input_ids = batch.input_ids.view(-1, 1).tolist()
@@ -810,7 +838,6 @@ class FlashCausalLMBatch(Batch):
 
            # Update
            cumulative_batch_size += len(batch)
-           cumulative_slots += len(batch.slots)
 
        if start_slots is not None:
            start_slots = torch.concat(start_slots)
@@ -915,7 +942,7 @@ class FlashCausalLMBatch(Batch):
            postfix_length,
            prompt_length,
            request_prefilling,
-           blocks
+           blocks,
        ) in enumerate(
            zip(
                self.requests,
@@ -923,7 +950,7 @@ class FlashCausalLMBatch(Batch):
                self.postfix_lengths,
                self.prompt_lengths,
                self.prefilling_mask,
-               self.block_tables
+               self.block_tables,
            )
        ):
            next_chunk_length = postfix_length
@@ -967,9 +994,7 @@ class FlashCausalLMBatch(Batch):
            no_prefill_logprobs = no_prefill_logprobs and not prefill_logprobs
 
            if prefill_logprobs:
-               prefill_head_indices.append(
-                   request_position_ids + cumulative_length
-               )
+               prefill_head_indices.append(request_position_ids + cumulative_length)
                prefill_next_token_indices.append(
                    prefill_out_cumulative_length + postfix_length - 1
                )
@@ -988,7 +1013,6 @@ class FlashCausalLMBatch(Batch):
                prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
                prefill_out_cumulative_length += 1
 
-
            start_slots.append(cumulative_slot_tokens)
            slots.extend(request_slots)
            slot_indices.append(request_slot_indices)
@@ -998,9 +1022,7 @@ class FlashCausalLMBatch(Batch):
 
            ADAPTER_TO_INDEX = get_adapter_to_index()
            adapter_index = ADAPTER_TO_INDEX.get(r.adapter_id, 0)
-           adapter_indices_list.append(
-               torch.full((next_chunk_length,), adapter_index)
-           )
+           adapter_indices_list.append(torch.full((next_chunk_length,), adapter_index))
            adapter_set.add(adapter_index)
 
            # Update
@@ -1240,6 +1262,7 @@ class FlashCausalLM(Model):
            rank=rank,
            world_size=world_size,
            sliding_window=config.sliding_window,
+           support_chunking=True,
        )
 
    @property
@@ -1764,29 +1787,43 @@ class FlashCausalLM(Model):
        finished_prefilling = True
        next_chunk_lengths = []
        if prefill:
-           next_prefilling_mask = []
-           # Budget in tokens for the next batch
-           # We remove next input ids to always have enough space for at least a single decode
-           # for the remaining requests
-           batch_budget = TOKEN_BUDGET - len(batch)
-           for prefix_length, postfix_length, prompt_length in zip(
-               batch.prefix_lengths, batch.postfix_lengths, batch.prompt_lengths
-           ):
-               remaining_prefill_tokens = max(
-                   prompt_length - prefix_length - postfix_length, 0
-               )
-               if remaining_prefill_tokens > 0:
-                   next_chunk_length = max(
-                       min(remaining_prefill_tokens, batch_budget), 1
+           if get_support_chunking():
+               next_prefilling_mask = []
+               # Budget in tokens for the next batch
+               # We remove len(batch) to always have enough space for at least a single decode
+               # for the remaining requests
+               batch_budget = get_max_prefill_tokens() - len(batch)
+               # We reverse to prioritize older requests
+               # zip() is not reversible so reverse the underlying lists instead
+               for prefix_length, postfix_length, prompt_length in zip(
+                   reversed(batch.prefix_lengths),
+                   reversed(batch.postfix_lengths),
+                   reversed(batch.prompt_lengths),
+               ):
+                   remaining_prefill_tokens = max(
+                       prompt_length - prefix_length - postfix_length, 0
                    )
-                   batch_budget -= next_chunk_length
-                   finished_prefilling = False
-                   next_prefilling_mask.append(True)
-               else:
-                   # Since speculation will be turned off, this is always true
-                   next_chunk_length = 1
-                   next_prefilling_mask.append(False)
-               next_chunk_lengths.append(next_chunk_length)
+                   if remaining_prefill_tokens > 0:
+                       next_chunk_length = max(
+                           min(remaining_prefill_tokens, batch_budget), 1
+                       )
+                       batch_budget -= next_chunk_length
+                       finished_prefilling = False
+                       next_prefilling_mask.append(True)
+                   else:
+                       # Since speculation will be turned off, this is always true
+                       next_chunk_length = 1
+                       next_prefilling_mask.append(False)
+                   next_chunk_lengths.append(next_chunk_length)
+
+               # Reverse back the obtained values²
+               next_chunk_lengths.reverse()
+               next_prefilling_mask.reverse()
+           else:
+               # The model does not support chunking
+               # We know we only do a single prefill
+               finished_prefilling = True
+               next_prefilling_mask = [False] * len(batch)
 
        batch.prefilling = not finished_prefilling
        batch.prefilling_mask = next_prefilling_mask
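A stand-alone sketch (not from the commit) of the per-request chunk sizing above, with made-up lengths; `next_chunk_lengths` here is a hypothetical helper used only for illustration, and the budget is passed in rather than read from `get_max_prefill_tokens()`.

def next_chunk_lengths(prefix_lengths, postfix_lengths, prompt_lengths, batch_budget):
    lengths = []
    for prefix_length, postfix_length, prompt_length in zip(
        reversed(prefix_lengths), reversed(postfix_lengths), reversed(prompt_lengths)
    ):
        remaining = max(prompt_length - prefix_length - postfix_length, 0)
        if remaining > 0:
            # Still prefilling: take as much of the budget as possible, at least 1 token.
            chunk = max(min(remaining, batch_budget), 1)
            batch_budget -= chunk
        else:
            # Request finished prefilling: it only needs its single decode token.
            chunk = 1
        lengths.append(chunk)
    lengths.reverse()  # restore the original request order
    return lengths

print(next_chunk_lengths([0, 0], [512, 256], [2048, 256], batch_budget=1024))
# -> [1024, 1]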
@@ -2179,7 +2216,9 @@ class FlashCausalLM(Model):
                    # have more than one new token per request with speculative decoding
                    for next_token_id in _next_token_ids:
                        batch.next_token_chooser = (
-                           batch.next_token_chooser.advance_grammar_single(i, next_token_id)
+                           batch.next_token_chooser.advance_grammar_single(
+                               i, next_token_id
+                           )
                        )
 
                    # Update values
@ -18,7 +18,7 @@ if PREFIX_CACHING and ATTENTION not in {"flashinfer", "flashdecoding"}:
|
|||||||
raise RuntimeError("Prefix caching is only supported with flashinfer")
|
raise RuntimeError("Prefix caching is only supported with flashinfer")
|
||||||
|
|
||||||
MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
|
MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
|
||||||
TGI_WIGGLE_ROOM = float(os.getenv("TGI_WIGGLE_ROOM", "0.95"))
|
TGI_WIGGLE_ROOM = float(os.getenv("TGI_WIGGLE_ROOM", "0.90"))
|
||||||
assert TGI_WIGGLE_ROOM > 0
|
assert TGI_WIGGLE_ROOM > 0
|
||||||
assert TGI_WIGGLE_ROOM < 1
|
assert TGI_WIGGLE_ROOM < 1
|
||||||
|
|
||||||
|
@@ -83,6 +83,7 @@ class IdeficsCausalLMBatch(Batch):
            request_ids=[r.id for r in self.requests],
            size=len(self),
            max_tokens=self.max_tokens,
+           current_tokens=len(self),
        )
 
    @classmethod
@@ -116,6 +116,7 @@ class MambaBatch(Batch):
            request_ids=[r.id for r in self.requests],
            size=len(self),
            max_tokens=self.max_tokens,
+           current_tokens=len(self),
        )
 
    @classmethod
@@ -5,8 +5,11 @@ from abc import ABC, abstractmethod
 from typing import List, Tuple, Optional, TypeVar, Type, Dict
 from collections import defaultdict
 from transformers import PreTrainedTokenizerBase
+from loguru import logger
 
 from text_generation_server.models.types import Batch, Generation
+from text_generation_server.utils.log import log_master
+from text_generation_server.utils.prefill_chunking import set_support_chunking
 from text_generation_server.utils.speculate import get_speculate
 from text_generation_server.pb.generate_pb2 import InfoResponse
 from text_generation_server.adapters.weights import LayerAdapterWeights
@@ -31,6 +34,7 @@ class Model(ABC):
        sliding_window: Optional[int] = None,
        speculate: Optional[int] = None,
        adapter_id: str = BASE_MODEL_ADAPTER_ID,
+       support_chunking: bool = False,
    ):
        self.model_id = model_id
        self.model = model.eval()
@@ -60,6 +64,17 @@ class Model(ABC):
            speculate = get_speculate()
        self.speculate = speculate
 
+       if speculate != 0 and support_chunking:
+           log_master(
+               logger.warning,
+               "Prefill chunking does not support speculation yet. "
+               "Prefill chunking will be turned off",
+           )
+           support_chunking = False
+
+       self.support_chunking = support_chunking
+       set_support_chunking(support_chunking)
+
        self.has_position_ids = (
            inspect.signature(model.forward).parameters.get("position_ids", None)
            is not None
@@ -78,6 +93,7 @@ class Model(ABC):
            device_type=self.device.type,
            window_size=self.sliding_window,
            speculate=self.speculate,
+           support_chunking=self.support_chunking,
        )
 
    @property
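A minimal stand-alone restatement (not from the commit) of the guard above: speculation and prefill chunking are mutually exclusive for now, so chunking is disabled when `speculate != 0`. `resolve_support_chunking` is a hypothetical helper used only for illustration.

def resolve_support_chunking(speculate: int, support_chunking: bool) -> bool:
    # Mirrors the warn-and-disable behaviour in Model.__init__ above.
    if speculate != 0 and support_chunking:
        print("Prefill chunking does not support speculation yet; turning it off")
        return False
    return support_chunking

assert resolve_support_chunking(0, True) is True
assert resolve_support_chunking(2, True) is False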
@@ -80,6 +80,7 @@ class Seq2SeqLMBatch(Batch):
            request_ids=[r.id for r in self.requests],
            size=len(self),
            max_tokens=self.max_tokens,
+           current_tokens=len(self),
        )
 
    @classmethod
@@ -357,7 +357,6 @@ class VlmCausalLM(FlashCausalLM):
        else:
            cuda_graph = None
        if cu_seqlen_prefill is not None or cuda_graph is None:
-           input_lengths = postfix_lengths + prefix_lengths_tensor
            if PREFIX_CACHING:
                block_tables = block_tables_to_ragged(
                    block_tables=block_tables,
@@ -424,7 +423,7 @@ class VlmCausalLM(FlashCausalLM):
            cuda_graph["postfix_lengths"][: postfix_lengths.shape[0]] = postfix_lengths
            cuda_graph["prefix_lengths"].zero_()
            cuda_graph["prefix_lengths"][
                : prefix_lengths_tensor.shape[0]
            ] = prefix_lengths_tensor
 
            with self._forward_context(
@@ -15,6 +15,7 @@ from text_generation_server.cache import Cache
 from text_generation_server.interceptor import ExceptionInterceptor
 from text_generation_server.models import Model, get_model_with_lora_adapters
 from text_generation_server.utils.adapter import AdapterInfo
+from text_generation_server.utils.prefill_chunking import set_max_prefill_tokens
 
 try:
     from text_generation_server.models.pali_gemma import PaliGemmaBatch
@@ -96,6 +97,8 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
        return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())
 
    async def Warmup(self, request, context):
+       set_max_prefill_tokens(request.max_prefill_tokens)
+
        if self.quantize in {"exl2", "gptq"}:
            try:
                # When using GPTQ, Exllama kernels need some global kernels
server/text_generation_server/utils/prefill_chunking.py (new file, 24 lines)
@@ -0,0 +1,24 @@
+from typing import Optional
+
+SUPPORT_CHUNKING: Optional[bool] = None
+MAX_PREFILL_TOKENS: Optional[int] = None
+
+
+def set_support_chunking(support_chunking: bool):
+    global SUPPORT_CHUNKING
+    SUPPORT_CHUNKING = support_chunking
+
+
+def get_support_chunking() -> bool:
+    global SUPPORT_CHUNKING
+    return SUPPORT_CHUNKING
+
+
+def set_max_prefill_tokens(max_prefill_tokens: int):
+    global MAX_PREFILL_TOKENS
+    MAX_PREFILL_TOKENS = max_prefill_tokens
+
+
+def get_max_prefill_tokens() -> int:
+    global MAX_PREFILL_TOKENS
+    return MAX_PREFILL_TOKENS
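The new module is written to by the model constructor (`set_support_chunking`) and by the gRPC `Warmup` handler (`set_max_prefill_tokens`), and read back in the batching code. A small usage sketch, with an assumed budget value:

from text_generation_server.utils.prefill_chunking import (
    set_support_chunking,
    get_support_chunking,
    set_max_prefill_tokens,
    get_max_prefill_tokens,
)

set_support_chunking(True)       # done once when the model is instantiated
set_max_prefill_tokens(4096)     # done in Warmup from request.max_prefill_tokens

if get_support_chunking():
    # Budget consulted when sizing the next prefill chunk
    print("prefill token budget:", get_max_prefill_tokens())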