load tested

OlivierDehaene 2024-10-02 12:59:44 +02:00
parent 34f5dc525e
commit 7f9abde3f8
22 changed files with 307 additions and 195 deletions

View File

@@ -159,6 +159,7 @@ impl Client {
             blocks: vec![],
             slots: vec![],
             prefix_len: 0,
+            postfix_len: truncate,
             // Set sampling parameters to also take these ops into account in the max memory
             parameters: Some(NextTokenChooserParameters {
                 temperature: 0.9,

View File

@@ -246,6 +246,7 @@ impl Health for ShardedClient {
            blocks: vec![0],
            slots: (0..16).collect(),
            prefix_len: 0,
+           postfix_len: 1,
            adapter_id: None,
        };
        let batch = Batch {

View File

@@ -34,9 +34,13 @@ impl BackendV3 {
        requires_padding: bool,
        window_size: Option<u32>,
        speculate: u32,
+       support_chunking: bool,
    ) -> Self {
-       let prefix_caching =
-           std::env::var("USE_PREFIX_CACHING").unwrap_or("1".to_string());
+       if support_chunking {
+           tracing::warn!("Model supports prefill chunking. `waiting_served_ratio` and `max_waiting_tokens` will be ignored.");
+       }
+
+       let prefix_caching = std::env::var("USE_PREFIX_CACHING").unwrap_or("1".to_string());
        let prefix_caching = matches!(prefix_caching.as_str(), "true" | "1");
        let attention: String = std::env::var("ATTENTION").unwrap_or("flashinfer".to_string());
@@ -52,6 +56,7 @@ impl BackendV3 {
            window_size,
            speculate,
            max_batch_total_tokens,
+           support_chunking,
        );
        let batching_task_notifier = Arc::new(Notify::new());
@@ -63,6 +68,7 @@ impl BackendV3 {
            max_batch_total_tokens,
            max_waiting_tokens,
            max_batch_size,
+           support_chunking,
            queue.clone(),
            batching_task_notifier.clone(),
        ));
@@ -127,6 +133,7 @@ pub(crate) async fn batching_task(
    max_batch_total_tokens: u32,
    max_waiting_tokens: usize,
    max_batch_size: Option<usize>,
+   support_chunking: bool,
    queue: Queue,
    notifier: Arc<Notify>,
) {
@@ -158,10 +165,24 @@ pub(crate) async fn batching_task(
        // Get current batch info
        let batch_size = batch.size;
        let batch_max_tokens = batch.max_tokens;
+       let current_tokens = batch.current_tokens;
        let mut batches = vec![batch];
        metrics::gauge!("tgi_batch_current_size").set(batch_size as f64);
        metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64);
+
+       let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
+
+       let (min_size, max_size, prefill_token_budget) = if support_chunking {
+           // Since the next batch will be concatenated with the current batch,
+           // the current batch tokens must be subtracted from the prefill budget
+           // In the future, we could concatenate beforehand
+           let prefill_token_budget = max_batch_prefill_tokens - current_tokens;
+           // We can ignore min_size and max_size
+           // Models that rely on max_size cannot support chunking
+           // Regarding min_size, chunking allows us to consistently run at the compute
+           // bound, making min_size useless.
+           (None, None, prefill_token_budget)
+       } else {
            let min_size = if waiting_tokens >= max_waiting_tokens {
                // If we didn't onboard any new requests since >= max_waiting_tokens, we try
                // to add a new batch even though its size might be small
@@ -173,13 +194,15 @@ pub(crate) async fn batching_task(
                Some((batch_size as f32 * waiting_served_ratio).floor() as usize)
            };
-           let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
            let max_size =
                max_batch_size.map(|max_size| max_size.saturating_sub(batch_size as usize));
+           (min_size, max_size, max_batch_prefill_tokens)
+       };
+
        // Try to get a new batch
        if let Some((mut new_entries, new_batch, span)) = queue
-           .next_batch(min_size, max_size, max_batch_prefill_tokens, token_budget)
+           .next_batch(min_size, max_size, prefill_token_budget, token_budget)
            .await
        {
            // Tracking metrics
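
Note on the budget arithmetic introduced above: when the model supports chunking, the queue is asked for new prefill work with a budget reduced by the tokens already scheduled for the next forward (`batch.current_tokens`), while the decode-side `token_budget` is computed as before. A standalone Python sketch of the same arithmetic with made-up numbers (the real code is the Rust shown above):

def next_batch_budgets(
    max_batch_total_tokens,
    max_batch_prefill_tokens,
    batch_max_tokens,
    current_tokens,
    support_chunking,
):
    # Room left for the running batch to grow (prefill + decode + speculation).
    token_budget = max(max_batch_total_tokens - batch_max_tokens, 0)
    # With chunking, new requests are concatenated with the running batch, so
    # tokens already scheduled for the next forward reduce the prefill budget.
    prefill_token_budget = (
        max_batch_prefill_tokens - current_tokens
        if support_chunking
        else max_batch_prefill_tokens
    )
    return token_budget, prefill_token_budget

print(next_batch_budgets(16384, 4096, 6000, 1500, True))   # (10384, 2596)
print(next_batch_budgets(16384, 4096, 6000, 1500, False))  # (10384, 4096)

With chunking there is also no `min_size`/`max_size` constraint, since the scheduler can always fill the prefill budget exactly.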

View File

@@ -159,6 +159,7 @@ impl Client {
            blocks: vec![],
            slots: vec![],
            prefix_len: 0,
+           postfix_len: truncate,
            // Set sampling parameters to also take these ops into account in the max memory
            parameters: Some(NextTokenChooserParameters {
                temperature: 0.9,

View File

@@ -29,15 +29,6 @@ pub trait Health {
     async fn model_health(&self) -> Result<()>;
 }

-#[derive(Debug)]
-pub struct ShardInfo {
-    pub requires_padding: bool,
-    pub dtype: String,
-    pub device_type: String,
-    pub window_size: Option<u32>,
-    pub speculate: u32,
-}
-
 #[derive(Error, Debug, Clone)]
 pub enum ClientError {
     #[error("Could not connect to Text Generation server: {0}")]

View File

@@ -1,6 +1,6 @@
-use crate::client::{ClientError, Result};
+use crate::client::Health;
 /// Multi shard Client
-use crate::client::{Health, ShardInfo};
+use crate::client::{ClientError, Result};
 use crate::client::grpc_client::{DecodeTimings, PrefillTimings};
 use crate::client::{
@@ -49,13 +49,13 @@ impl ShardedClient {
    /// Get the model info
    #[instrument(skip(self))]
-   pub async fn info(&mut self) -> Result<ShardInfo> {
+   pub async fn info(&mut self) -> Result<InfoResponse> {
        let futures: Vec<_> = self
            .clients
            .iter_mut()
            .map(|client| client.info())
            .collect();
-       join_all(futures).await.pop().unwrap().map(ShardInfo::from)
+       join_all(futures).await.pop().unwrap()
    }

    /// GRPC health check
@@ -194,18 +194,6 @@ impl ShardedClient {
    }
}

-impl From<InfoResponse> for ShardInfo {
-    fn from(value: InfoResponse) -> Self {
-        Self {
-            requires_padding: value.requires_padding,
-            dtype: value.dtype,
-            device_type: value.device_type,
-            window_size: value.window_size,
-            speculate: value.speculate,
-        }
-    }
-}
-
 #[async_trait]
 impl Health for ShardedClient {
    async fn device_health(&self) -> Result<()> {
@@ -248,6 +236,7 @@ impl Health for ShardedClient {
            slots: (0..16).collect(),
            prefix_len: 0,
            adapter_id: None,
+           postfix_len: 1,
        };
        let batch = Batch {
            id: u64::MAX,

View File

@@ -29,6 +29,8 @@ pub struct BackendInfo {
    pub max_waiting_tokens: usize,
    #[schema(nullable = true, example = "null")]
    pub max_batch_size: Option<usize>,
+   #[schema(example = "false")]
+   pub support_chunking: bool,
 }

 #[allow(clippy::too_many_arguments)]
@@ -110,6 +112,7 @@ pub async fn connect_backend(
        model_device_type: shard_info.device_type.clone(),
        model_dtype: shard_info.dtype.clone(),
        speculate: shard_info.speculate as usize,
+       support_chunking: shard_info.support_chunking,
    };

    let backend = BackendV3::new(
@@ -122,6 +125,7 @@ pub async fn connect_backend(
        shard_info.requires_padding,
        shard_info.window_size,
        shard_info.speculate,
+       shard_info.support_chunking,
    );

    tracing::info!("Using backend V3");

View File

@@ -131,25 +131,12 @@ async fn main() -> Result<(), RouterError> {
            "`max_input_tokens` must be < `max_total_tokens`".to_string(),
        ));
    }
-   if max_input_tokens as u32 > max_batch_prefill_tokens {
-       return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}")));
-   }
-
    if validation_workers == 0 {
        return Err(RouterError::ArgumentValidation(
            "`validation_workers` must be > 0".to_string(),
        ));
    }
-
-   if let Some(ref max_batch_total_tokens) = max_batch_total_tokens {
-       if max_batch_prefill_tokens > *max_batch_total_tokens {
-           return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")));
-       }
-       if max_total_tokens as u32 > *max_batch_total_tokens {
-           return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")));
-       }
-   }
-
    if let Some(max_batch_size) = max_batch_size {
        if max_batch_size == 0 {
            return Err(RouterError::ArgumentValidation(
@@ -158,7 +145,7 @@ async fn main() -> Result<(), RouterError> {
        }
    }

-   let (backend, _backend_info) = connect_backend(
+   let (backend, backend_info) = connect_backend(
        max_input_tokens,
        max_total_tokens,
        master_shard_uds_path,
@@ -170,6 +157,19 @@ async fn main() -> Result<(), RouterError> {
    )
    .await?;

+   // Validate remaining args now that the backend is known
+   let support_chunking = backend_info.support_chunking;
+   let max_batch_total_tokens = backend_info.max_batch_total_tokens;
+   if max_input_tokens as u32 > max_batch_prefill_tokens && !support_chunking {
+       return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}")));
+   }
+   if max_batch_prefill_tokens > max_batch_total_tokens {
+       return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")));
+   }
+   if max_total_tokens as u32 > max_batch_total_tokens {
+       return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")));
+   }
+
    // Run server
    server::run(
        backend,
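
Because the checks against `max_batch_total_tokens` and the `max_input_tokens` check now depend on what the backend reports, they run after `connect_backend` returns, and the first one is skipped when chunking is supported (a long prompt can then be prefilled over several forwards). A rough Python sketch of the resulting rules (hypothetical helper, not the router code):

def validate_after_connect(max_input_tokens, max_total_tokens,
                           max_batch_prefill_tokens, backend_info):
    if (max_input_tokens > max_batch_prefill_tokens
            and not backend_info["support_chunking"]):
        raise ValueError("`max_batch_prefill_tokens` must be >= `max_input_tokens`")
    if max_batch_prefill_tokens > backend_info["max_batch_total_tokens"]:
        raise ValueError("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`")
    if max_total_tokens > backend_info["max_batch_total_tokens"]:
        raise ValueError("`max_total_tokens` must be <= `max_batch_total_tokens`")

# Accepted with chunking: the 8192-token prompt exceeds the 4096-token prefill
# budget, so it will simply be prefilled in chunks.
validate_after_connect(8192, 8448, 4096,
                       {"support_chunking": True, "max_batch_total_tokens": 16384})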

View File

@@ -4,7 +4,7 @@ use crate::client::{
    Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
 };
 use nohash_hasher::{BuildNoHashHasher, IntMap};
-use std::cmp::{max, min};
+use std::cmp::max;
 use std::collections::VecDeque;
 use text_generation_router::infer::InferError;
 use text_generation_router::infer::InferStreamResponse;
@@ -50,6 +50,7 @@ impl Queue {
        window_size: Option<u32>,
        speculate: u32,
        max_batch_total_tokens: u32,
+       support_chunking: bool,
    ) -> Self {
        // Create channel
        let (queue_sender, queue_receiver) = mpsc::unbounded_channel();
@@ -62,6 +63,7 @@ impl Queue {
            window_size,
            speculate,
            max_batch_total_tokens,
+           support_chunking,
            queue_receiver,
        ));
@@ -108,6 +110,7 @@ impl Queue {
 }

 // Background task responsible of the queue state
+#[allow(clippy::too_many_arguments)]
 async fn queue_task(
    requires_padding: bool,
    block_size: u32,
@@ -115,6 +118,7 @@ async fn queue_task(
    window_size: Option<u32>,
    speculate: u32,
    max_batch_total_tokens: u32,
+   support_chunking: bool,
    mut receiver: mpsc::UnboundedReceiver<QueueCommand>,
 ) {
    let mut state = State::new(
@@ -124,6 +128,7 @@ async fn queue_task(
        window_size,
        speculate,
        max_batch_total_tokens,
+       support_chunking,
    );

    while let Some(cmd) = receiver.recv().await {
@@ -166,12 +171,14 @@ struct State {
    /// Paged Attention block size
    block_size: u32,

-   /// Sliding window
-   window_size: Option<u32>,
-
    /// Speculation amount
    speculate: u32,

+   /// Whether the model allows prefill chunking.
+   /// If it does, the last request in the batch will be split to exactly match the
+   /// prefill token budget.
+   support_chunking: bool,
+
    /// Paged Attention Block Allocation
    block_allocator: Option<BlockAllocator>,
 }
@@ -184,6 +191,7 @@ impl State {
        window_size: Option<u32>,
        speculate: u32,
        max_batch_total_tokens: u32,
+       support_chunking: bool,
    ) -> Self {
        let block_allocator = (!requires_padding).then(|| {
            BlockAllocator::new(
@@ -199,8 +207,8 @@ impl State {
            next_id: 0,
            next_batch_id: 0,
            block_size,
-           window_size,
            speculate,
+           support_chunking,
            block_allocator,
        }
    }
@@ -268,7 +276,7 @@ impl State {
                continue;
            }

-           let block_allocation = match &self.block_allocator {
+           let (block_allocation, postfix_len) = match &self.block_allocator {
                None => {
                    // We pad to max input length in the Python shards
                    // We need to take these padding tokens into the equation
@@ -285,34 +293,9 @@ impl State {
                        self.entries.push_front((id, entry));
                        break 'entry_loop;
                    }
-                   None
+                   (None, entry.request.input_length)
                }
-               Some(_block_allocator) => {
-                   prefill_tokens += entry.request.input_length;
-                   let max_new_tokens = match self.window_size {
-                       None => entry.request.stopping_parameters.max_new_tokens,
-                       Some(window_size) => min(
-                           window_size.saturating_sub(entry.request.input_length),
-                           entry.request.stopping_parameters.max_new_tokens,
-                       ),
-                   };
-                   decode_tokens += max_new_tokens;
-                   if prefill_tokens > prefill_token_budget
-                       || (prefill_tokens + decode_tokens + self.speculate) > token_budget
-                   {
-                       // Entry is over budget
-                       // Add it back to the front
-                       tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
-                       self.entries.push_front((id, entry));
-                       break;
-                   }
-                   let tokens = entry.request.input_length
-                       + entry.request.stopping_parameters.max_new_tokens
-                       + self.speculate
-                       - 1;
+               Some(block_allocator) => {
                    // If users wants the prefill logprobs, we cannot reuse the cache.
                    // So no input_ids for the radix tree.
                    let input_ids = if entry.request.decoder_input_details {
@@ -321,10 +304,65 @@ impl State {
                        entry.request.input_ids.clone()
                    };

-                   Some((tokens, input_ids))
+                   let tokens = entry.request.input_length
+                       + entry.request.stopping_parameters.max_new_tokens
+                       + self.speculate
+                       - 1;
+
+                   tracing::debug!("Allocating {tokens} with {input_ids:?}");
+                   let block_allocation = match block_allocator.allocate(tokens, input_ids).await {
+                       None => {
+                           // Entry is over budget
+                           // Add it back to the front
+                           tracing::debug!("Over budget: not enough free blocks");
+                           self.entries.push_front((id, entry));
+                           break 'entry_loop;
+                       }
+                       Some(mut block_allocation) => {
+                           tracing::debug!("Allocation: {block_allocation:?}");
+                           max_blocks = max(max_blocks, block_allocation.blocks.len() as u32);
+
+                           if block_allocation.prefix_len == entry.request.input_length {
+                               // The whole request was found in the radix trie.
+                               // However, for the transformer forward to work, we need to
+                               // have at least one token of postfix.
+                               block_allocation.prefix_len -= 1;
+                           }
+
+                           block_allocation
+                       }
                    }
                };
-           batch.push((id, entry, block_allocation));
+
+                   let mut postfix_len = entry.request.input_length - block_allocation.prefix_len;
+
+                   // Check equality too, because if we don't we might end up with a
+                   // postfix_len = 0 in the next iteration of the loop.
+                   if prefill_tokens + postfix_len >= prefill_token_budget {
+                       // Entry is over budget
+                       if self.support_chunking {
+                           // We support chunking, just set postfix_len to exactly match prefill_token_budget
+                           postfix_len = prefill_token_budget - prefill_tokens;
+                           // Push this entry inside the batch
+                           batch.push((id, entry, Some(block_allocation), postfix_len));
+                           break 'entry_loop;
+                       } else {
+                           // We don't support chunking, this entry needs to go back to the buffer
+                           // Add it back to the front
+                           tracing::debug!(
+                               "Over budget: prefill_tokens={} > {prefill_token_budget}",
+                               prefill_tokens + postfix_len
+                           );
+                           self.entries.push_front((id, entry));
+                           break 'entry_loop;
+                       }
+                   }
+
+                   prefill_tokens += postfix_len;
+
+                   (Some(block_allocation), postfix_len)
+               }
+           };
+
+           batch.push((id, entry, block_allocation, postfix_len));
            if Some(batch.len()) == max_size {
                break;
            }
@@ -342,7 +380,7 @@ impl State {
        // Batch is too small
        if batch.len() < min_size {
            // Add back entries to the queue in the correct order
-           for (id, entry, _) in batch.into_iter().rev() {
+           for (id, entry, _, _) in batch.into_iter().rev() {
                self.entries.push_front((id, entry));
            }
            return None;
@@ -353,29 +391,7 @@ impl State {
        let mut batch_entries =
            IntMap::with_capacity_and_hasher(self.entries.len(), BuildNoHashHasher::default());

-       for (id, mut entry, block_allocation) in batch {
-           let block_allocation = if let (Some((tokens, input_ids)), Some(block_allocator)) =
-               (block_allocation, &self.block_allocator)
-           {
-               tracing::debug!("Allocating {tokens} with {input_ids:?}");
-               match block_allocator.allocate(tokens, input_ids).await {
-                   None => {
-                       // Entry is over budget
-                       // Add it back to the front
-                       tracing::debug!("Over budget: not enough free blocks");
-                       self.entries.push_front((id, entry));
-                       continue;
-                   }
-                   Some(block_allocation) => {
-                       tracing::debug!("Allocation: {block_allocation:?}");
-                       max_blocks = max(max_blocks, block_allocation.blocks.len() as u32);
-                       Some(block_allocation)
-                   }
-               }
-           } else {
-               None
-           };
-           tracing::debug!("Accepting entry");
+       for (id, mut entry, block_allocation, postfix_len) in batch {
            // Create a new span to link the batch back to this entry
            let entry_batch_span = info_span!(parent: &entry.span, "infer");
            // Add relationships
@@ -429,6 +445,7 @@ impl State {
                slots,
                prefix_len,
                adapter_id: entry.request.adapter_id.clone(),
+               postfix_len,
            });
            // Set batch_time
            entry.batch_time = Some(Instant::now());
@@ -436,12 +453,6 @@ impl State {
            batch_entries.insert(id, entry);
        }

-       // Empty batch
-       if batch_requests.is_empty() {
-           tracing::debug!("Filterered out all entries");
-           return None;
-       }
-
        // Final batch size
        let size = batch_requests.len() as u32;
        next_batch_span.record("batch_size", size);
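
The core of the change above is how a queued entry is fitted into the prefill budget once its blocks are allocated: the cached prefix is reused, at least one postfix token is always kept, and with chunking the last entry's postfix is clamped to whatever budget remains. A standalone Python sketch under those assumptions (the real code is the Rust above):

def fit_entry(input_length, prefix_len, prefill_tokens, prefill_token_budget,
              support_chunking):
    if prefix_len == input_length:
        # The whole prompt was found in the radix trie, but the forward pass
        # still needs at least one input token.
        prefix_len -= 1
    postfix_len = input_length - prefix_len
    if prefill_tokens + postfix_len >= prefill_token_budget:
        if not support_chunking:
            return None  # over budget: the entry goes back to the queue
        # Chunking: prefill only what fits now, the rest is done in later forwards.
        postfix_len = prefill_token_budget - prefill_tokens
    return prefix_len, postfix_len

print(fit_entry(input_length=300, prefix_len=50,
                prefill_tokens=400, prefill_token_budget=512,
                support_chunking=True))
# (50, 112): only 112 of the 250 uncached tokens are prefilled in this batch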

View File

@@ -159,6 +159,7 @@ async fn prefill(
            blocks: vec![],
            slots: vec![],
            prefix_len: 0,
+           postfix_len: sequence_length,
            adapter_id: None,
        })
        .collect();

View File

@@ -34,6 +34,7 @@ message InfoResponse {
    string device_type = 3;
    optional uint32 window_size = 4;
    uint32 speculate = 5;
+   bool support_chunking = 6;
 }

 /// Empty request
@@ -139,6 +140,8 @@ message Request {
    uint32 prefix_len = 12;
    /// Context truncation
    bool add_special_tokens = 13;
+   /// Postfix length for prefill chunking
+   uint32 postfix_len = 14;
 }

 message Batch {
@@ -163,6 +166,8 @@ message CachedBatch {
    uint32 size = 3;
    /// Maximum number of tokens this batch will grow to
    uint32 max_tokens = 4;
+   /// Number of tokens in the next forward
+   uint32 current_tokens = 5;
 }

 enum FinishReason {
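
All three additions are trailing scalar fields, so they are wire-compatible: an older peer that never sets them reads the proto3 defaults (0 or false). A quick check with the generated Python stubs, assuming they have been regenerated from this proto:

from text_generation_server.pb import generate_pb2

req = generate_pb2.Request(prefix_len=0, postfix_len=1)
print(req.postfix_len)        # 1

cached = generate_pb2.CachedBatch(size=2, max_tokens=64, current_tokens=2)
print(cached.current_tokens)  # 2
print(generate_pb2.InfoResponse().support_chunking)  # False (proto3 default)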

View File

@@ -1,7 +1,7 @@
 import pytest
-import os

 from text_generation_server.pb import generate_pb2


 @pytest.fixture
 def default_pb_parameters():
     return generate_pb2.NextTokenChooserParameters(

View File

@@ -76,6 +76,7 @@ class CausalLMBatch(Batch):
            request_ids=[r.id for r in self.requests],
            size=len(self),
            max_tokens=self.max_tokens,
+           current_tokens=len(self),
        )

    @classmethod

View File

@@ -16,7 +16,17 @@ from transformers import (
    AutoTokenizer,
    GenerationConfig,
 )
-from typing import Any, ContextManager, Iterable, Optional, Tuple, List, Type, Dict, Union
+from typing import (
+    Any,
+    ContextManager,
+    Iterable,
+    Optional,
+    Tuple,
+    List,
+    Type,
+    Dict,
+    Union,
+)

 from text_generation_server.adapters import AdapterBatchData, AdapterBatchMetadata
 from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
@@ -24,6 +34,10 @@ from text_generation_server.utils.chunks import concat_text_chunks
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.models import Model
 from text_generation_server.utils.log import log_master
+from text_generation_server.utils.prefill_chunking import (
+    get_support_chunking,
+    get_max_prefill_tokens,
+)
 from text_generation_server.utils.tokens import batch_top_tokens
 from text_generation_server.utils.speculate import get_speculate
 from text_generation_server.utils import (
@@ -60,12 +74,9 @@ from text_generation_server.utils.import_utils import (
 tracer = trace.get_tracer(__name__)

 # Will be set in init
 SLIDING_WINDOW: Optional[int] = None
-
-TOKEN_BUDGET = 8


 def set_sliding_window(sliding_window: int):
     global SLIDING_WINDOW
@@ -206,6 +217,11 @@ class FlashCausalLMBatch(Batch):
            request_ids=[r.id for r in self.requests],
            size=len(self),
            max_tokens=self.num_blocks * BLOCK_SIZE,
+           current_tokens=(
+               sum([len(i) for i in self.input_ids])
+               if isinstance(self.input_ids, list)
+               else len(self.input_ids)
+           ),
        )

    @classmethod
@@ -272,7 +288,7 @@
            prompt_lengths.append(prompt_length)

            prefix_length = r.prefix_len
-           postfix_length = prefix_length + 10
+           postfix_length = r.postfix_len
            assert (
                prefix_length <= prompt_length
            ), f"Prefix {prefix_length} vs input {prompt_length}"
@@ -282,10 +298,13 @@
            if prefix_length + postfix_length < prompt_length:
                # FIXME: speculate is not supported for context chunking at the moment
                assert speculate == 0
+               assert get_support_chunking()
+               assert postfix_length > 0

            prefix_ids.append(tokenized_input[:prefix_length])
-           postfix_ids = tokenized_input[prefix_length : postfix_length]
-           # postfix_ids = tokenized_input[prefix_length:]
+           postfix_ids = tokenized_input[
+               prefix_length : prefix_length + postfix_length
+           ]
            postfix_length = len(postfix_ids)

            postfix_lengths.append(postfix_length)
@@ -371,7 +390,6 @@
            requests=pb.requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=all_postfix_ids,
-
            block_tables=block_tables,
            block_tables_tensor=block_tables_tensor,
            prefix_lengths=prefix_lengths,
@@ -395,7 +413,6 @@
            max_blocks=max_blocks,
            speculative_ids=None,
            prompt_lengths_tensor=prompt_lengths_tensor,
-
            # These values will be set by `FlashCausalLMBatch.prepare_for_prefill`
            position_ids=None,
            cu_seqlen_prefill=None,
@@ -431,7 +448,7 @@
        if len(request_ids) == len(self):
            return self

-       device = self.input_ids.device
+       device = self.block_tables_tensor.device

        # New values after filtering
        requests_idx_mapping = {}
@@ -643,24 +660,24 @@
        max_current_length = 0
        for b in batches:
            total_batch_size += len(b)
-           total_slots += len(b.slots)
+           max_blocks = max(max_blocks, b.max_blocks)
+           # If `b` is prefilling and was just filtered, `b.slots` is None
+           # `total_slots` is not used if any of the batches is prefilling
+           total_slots += len(b.slots) if not b.prefilling else 0
            num_blocks += b.num_blocks
            speculative_length = (
                b.speculative_ids.shape[1] if b.speculative_ids is not None else 0
            )
-           max_blocks = max(max_blocks, b.max_blocks)
            max_postfix_length = max(max_postfix_length, b.max_postfix_length)
            max_current_length = max(max_current_length, b.max_current_length)
            max_length = max(
                max_length,
                max(
-                   prefix_length
-                   + postfix_length
+                   prompt_length
                    + stopping_criteria.max_new_tokens
                    + speculative_length
-                   - stopping_criteria.current_tokens
-                   for prefix_length, postfix_length, stopping_criteria in zip(
-                       b.prefix_lengths, b.postfix_lengths, b.stopping_criterias
+                   for prompt_length, stopping_criteria in zip(
+                       b.prompt_lengths, b.stopping_criterias
                    )
                ),
            )
@@ -746,8 +763,6 @@
            start_index = cumulative_batch_size
            end_index = cumulative_batch_size + len(batch)

-           slots_start_index = cumulative_slots
-           slots_end_index = cumulative_slots + len(batch.slots)
-
            # Copy tensors (GPU)
            top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
@@ -761,10 +776,17 @@
            prompt_lengths_tensor[start_index:end_index] = batch.prompt_lengths_tensor

            if not prefilling:
+               slots_start_index = cumulative_slots
+               slots_end_index = cumulative_slots + len(batch.slots)
+
                input_ids[start_index:end_index] = batch.input_ids
                position_ids[start_index:end_index] = batch.position_ids
-               slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots
-               postfix_lengths_tensor[start_index:end_index] = batch.postfix_lengths_tensor
+               slot_indices[start_index:end_index] = (
+                   batch.slot_indices + cumulative_slots
+               )
+               postfix_lengths_tensor[start_index:end_index] = (
+                   batch.postfix_lengths_tensor
+               )
                slots[slots_start_index:slots_end_index] = batch.slots

                # Copy over adapter indices
@@ -779,11 +801,17 @@
                cumulative_adapter_indices_size = adapter_end_index
                adapter_set.update(batch.adapter_meta.adapter_set)
                adapter_segment_builder.concat(
-                   batch.adapter_meta.adapter_segments, batch.adapter_meta.segment_indices
+                   batch.adapter_meta.adapter_segments,
+                   batch.adapter_meta.segment_indices,
                )
-               prefix_lengths_tensor[start_index:end_index] = batch.prefix_lengths_tensor
+               prefix_lengths_tensor[start_index:end_index] = (
+                   batch.prefix_lengths_tensor
+               )

                start_slots.append(batch.start_slots + cumulative_slots)

+               # Update
+               cumulative_slots += len(batch.slots)
            else:
                if isinstance(batch.input_ids, torch.Tensor):
                    batch.input_ids = batch.input_ids.view(-1, 1).tolist()
@@ -810,7 +838,6 @@
            # Update
            cumulative_batch_size += len(batch)
-           cumulative_slots += len(batch.slots)

            if start_slots is not None:
                start_slots = torch.concat(start_slots)
@@ -915,7 +942,7 @@
                postfix_length,
                prompt_length,
                request_prefilling,
-               blocks
+               blocks,
            ) in enumerate(
                zip(
                    self.requests,
@@ -923,7 +950,7 @@
                    self.postfix_lengths,
                    self.prompt_lengths,
                    self.prefilling_mask,
-                   self.block_tables
+                   self.block_tables,
                )
            ):
                next_chunk_length = postfix_length
@@ -967,9 +994,7 @@
                no_prefill_logprobs = no_prefill_logprobs and not prefill_logprobs

                if prefill_logprobs:
-                   prefill_head_indices.append(
-                       request_position_ids + cumulative_length
-                   )
+                   prefill_head_indices.append(request_position_ids + cumulative_length)
                    prefill_next_token_indices.append(
                        prefill_out_cumulative_length + postfix_length - 1
                    )
@@ -988,7 +1013,6 @@
                    prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
                    prefill_out_cumulative_length += 1

-               start_slots.append(cumulative_slot_tokens)
                slots.extend(request_slots)
                slot_indices.append(request_slot_indices)
@@ -998,9 +1022,7 @@
                ADAPTER_TO_INDEX = get_adapter_to_index()
                adapter_index = ADAPTER_TO_INDEX.get(r.adapter_id, 0)
-               adapter_indices_list.append(
-                   torch.full((next_chunk_length,), adapter_index)
-               )
+               adapter_indices_list.append(torch.full((next_chunk_length,), adapter_index))
                adapter_set.add(adapter_index)

                # Update
@@ -1240,6 +1262,7 @@ class FlashCausalLM(Model):
            rank=rank,
            world_size=world_size,
            sliding_window=config.sliding_window,
+           support_chunking=True,
        )

    @property
@@ -1764,13 +1787,18 @@
        finished_prefilling = True
        next_chunk_lengths = []
        if prefill:
+           if get_support_chunking():
                next_prefilling_mask = []
                # Budget in tokens for the next batch
-               # We remove next input ids to always have enough space for at least a single decode
+               # We remove len(batch) to always have enough space for at least a single decode
                # for the remaining requests
-               batch_budget = TOKEN_BUDGET - len(batch)
+               batch_budget = get_max_prefill_tokens() - len(batch)
+               # We reverse to prioritize older requests
+               # zip() is not reversible so reverse the underlying lists instead
                for prefix_length, postfix_length, prompt_length in zip(
-                   batch.prefix_lengths, batch.postfix_lengths, batch.prompt_lengths
+                   reversed(batch.prefix_lengths),
+                   reversed(batch.postfix_lengths),
+                   reversed(batch.prompt_lengths),
                ):
                    remaining_prefill_tokens = max(
                        prompt_length - prefix_length - postfix_length, 0
@@ -1788,6 +1816,15 @@
                        next_prefilling_mask.append(False)
                    next_chunk_lengths.append(next_chunk_length)

+               # Reverse back the obtained values
+               next_chunk_lengths.reverse()
+               next_prefilling_mask.reverse()
+           else:
+               # The model does not support chunking
+               # We know we only do a single prefill
+               finished_prefilling = True
+               next_prefilling_mask = [False] * len(batch)
+
            batch.prefilling = not finished_prefilling
            batch.prefilling_mask = next_prefilling_mask
@@ -2179,7 +2216,9 @@
                # have more than one new token per request with speculative decoding
                for next_token_id in _next_token_ids:
                    batch.next_token_chooser = (
-                       batch.next_token_chooser.advance_grammar_single(i, next_token_id)
+                       batch.next_token_chooser.advance_grammar_single(
+                           i, next_token_id
+                       )
                    )

        # Update values
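
The chunk-planning loop above distributes the per-batch prefill budget (`get_max_prefill_tokens() - len(batch)`) across the requests' remaining prompt tokens, iterating the reversed lists so that, per the comment, older requests are served first, then reversing the results back. The exact loop body is only partly visible in this diff, so the helper below is a simplified sketch of that distribution, not the model code:

def plan_next_chunks(prefix_lengths, postfix_lengths, prompt_lengths,
                     max_prefill_tokens):
    # One token per request is reserved so every request can still decode.
    batch_budget = max_prefill_tokens - len(prompt_lengths)
    next_chunk_lengths, next_prefilling_mask = [], []
    for prefix, postfix, prompt in zip(
        reversed(prefix_lengths), reversed(postfix_lengths), reversed(prompt_lengths)
    ):
        remaining = max(prompt - prefix - postfix, 0)
        chunk = min(remaining, batch_budget)
        batch_budget -= chunk
        next_prefilling_mask.append(chunk > 0)
        next_chunk_lengths.append(chunk)
    next_chunk_lengths.reverse()
    next_prefilling_mask.reverse()
    return next_chunk_lengths, next_prefilling_mask

print(plan_next_chunks(prefix_lengths=[0, 0], postfix_lengths=[512, 200],
                       prompt_lengths=[2048, 300], max_prefill_tokens=1024))
# ([922, 100], [True, True]): both requests stay in prefilling mode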

View File

@@ -18,7 +18,7 @@ if PREFIX_CACHING and ATTENTION not in {"flashinfer", "flashdecoding"}:
     raise RuntimeError("Prefix caching is only supported with flashinfer")

 MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
-TGI_WIGGLE_ROOM = float(os.getenv("TGI_WIGGLE_ROOM", "0.95"))
+TGI_WIGGLE_ROOM = float(os.getenv("TGI_WIGGLE_ROOM", "0.90"))
 assert TGI_WIGGLE_ROOM > 0
 assert TGI_WIGGLE_ROOM < 1

View File

@@ -83,6 +83,7 @@ class IdeficsCausalLMBatch(Batch):
            request_ids=[r.id for r in self.requests],
            size=len(self),
            max_tokens=self.max_tokens,
+           current_tokens=len(self),
        )

    @classmethod

View File

@@ -116,6 +116,7 @@ class MambaBatch(Batch):
            request_ids=[r.id for r in self.requests],
            size=len(self),
            max_tokens=self.max_tokens,
+           current_tokens=len(self),
        )

    @classmethod

View File

@@ -5,8 +5,11 @@ from abc import ABC, abstractmethod
 from typing import List, Tuple, Optional, TypeVar, Type, Dict
 from collections import defaultdict
 from transformers import PreTrainedTokenizerBase
+from loguru import logger

 from text_generation_server.models.types import Batch, Generation
+from text_generation_server.utils.log import log_master
+from text_generation_server.utils.prefill_chunking import set_support_chunking
 from text_generation_server.utils.speculate import get_speculate
 from text_generation_server.pb.generate_pb2 import InfoResponse
 from text_generation_server.adapters.weights import LayerAdapterWeights
@@ -31,6 +34,7 @@ class Model(ABC):
        sliding_window: Optional[int] = None,
        speculate: Optional[int] = None,
        adapter_id: str = BASE_MODEL_ADAPTER_ID,
+       support_chunking: bool = False,
    ):
        self.model_id = model_id
        self.model = model.eval()
@@ -60,6 +64,17 @@
            speculate = get_speculate()
        self.speculate = speculate

+       if speculate != 0 and support_chunking:
+           log_master(
+               logger.warning,
+               "Prefill chunking does not support speculation yet. "
+               "Prefill chunking will be turned off",
+           )
+           support_chunking = False
+
+       self.support_chunking = support_chunking
+       set_support_chunking(support_chunking)
+
        self.has_position_ids = (
            inspect.signature(model.forward).parameters.get("position_ids", None)
            is not None
@@ -78,6 +93,7 @@
            device_type=self.device.type,
            window_size=self.sliding_window,
            speculate=self.speculate,
+           support_chunking=self.support_chunking,
        )

    @property
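
The rule added above gives speculative decoding precedence over prefill chunking for now. A tiny sketch of that resolution (hypothetical helper, not the Model API):

def resolve_support_chunking(speculate: int, support_chunking: bool) -> bool:
    if speculate != 0 and support_chunking:
        # Mirrors the warning logged above: chunking is disabled, speculation wins.
        return False
    return support_chunking

print(resolve_support_chunking(speculate=2, support_chunking=True))   # False
print(resolve_support_chunking(speculate=0, support_chunking=True))   # True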

View File

@@ -80,6 +80,7 @@ class Seq2SeqLMBatch(Batch):
            request_ids=[r.id for r in self.requests],
            size=len(self),
            max_tokens=self.max_tokens,
+           current_tokens=len(self),
        )

    @classmethod

View File

@@ -357,7 +357,6 @@ class VlmCausalLM(FlashCausalLM):
        else:
            cuda_graph = None
        if cu_seqlen_prefill is not None or cuda_graph is None:
-           input_lengths = postfix_lengths + prefix_lengths_tensor
            if PREFIX_CACHING:
                block_tables = block_tables_to_ragged(
                    block_tables=block_tables,

View File

@@ -15,6 +15,7 @@ from text_generation_server.cache import Cache
 from text_generation_server.interceptor import ExceptionInterceptor
 from text_generation_server.models import Model, get_model_with_lora_adapters
 from text_generation_server.utils.adapter import AdapterInfo
+from text_generation_server.utils.prefill_chunking import set_max_prefill_tokens

 try:
     from text_generation_server.models.pali_gemma import PaliGemmaBatch
@@ -96,6 +97,8 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
        return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())

    async def Warmup(self, request, context):
+       set_max_prefill_tokens(request.max_prefill_tokens)
+
        if self.quantize in {"exl2", "gptq"}:
            try:
                # When using GPTQ, Exllama kernels need some global kernels

View File

@@ -0,0 +1,24 @@
+from typing import Optional
+
+SUPPORT_CHUNKING: Optional[bool] = None
+MAX_PREFILL_TOKENS: Optional[int] = None
+
+
+def set_support_chunking(support_chunking: bool):
+    global SUPPORT_CHUNKING
+    SUPPORT_CHUNKING = support_chunking
+
+
+def get_support_chunking() -> bool:
+    global SUPPORT_CHUNKING
+    return SUPPORT_CHUNKING
+
+
+def set_max_prefill_tokens(max_prefill_tokens: int):
+    global MAX_PREFILL_TOKENS
+    MAX_PREFILL_TOKENS = max_prefill_tokens
+
+
+def get_max_prefill_tokens() -> int:
+    global MAX_PREFILL_TOKENS
+    return MAX_PREFILL_TOKENS
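
Minimal usage sketch of the new module, assuming the text-generation-server package is importable; both values are process-wide globals, set once at model construction and warmup and read back later when planning prefill chunks:

from text_generation_server.utils.prefill_chunking import (
    set_support_chunking,
    get_support_chunking,
    set_max_prefill_tokens,
    get_max_prefill_tokens,
)

set_support_chunking(True)       # done in Model.__init__
set_max_prefill_tokens(4096)     # done in TextGenerationService.Warmup

if get_support_chunking():
    # e.g. the budget FlashCausalLM starts from when planning next chunks
    print(get_max_prefill_tokens() - 8)  # 4088 for a batch of 8 requests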