From 7f9abde3f8a769acf1fb61c824b807e13bde80c1 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Wed, 2 Oct 2024 12:59:44 +0200 Subject: [PATCH] load tested --- backends/client/src/v3/client.rs | 1 + backends/client/src/v3/sharded_client.rs | 1 + backends/v3/src/backend.rs | 55 ++++-- backends/v3/src/client/grpc_client.rs | 1 + backends/v3/src/client/mod.rs | 9 - backends/v3/src/client/sharded_client.rs | 21 +- backends/v3/src/lib.rs | 4 + backends/v3/src/main.rs | 28 +-- backends/v3/src/queue.rs | 141 +++++++------- benchmark/src/generation.rs | 1 + proto/v3/generate.proto | 5 + server/tests/conftest.py | 2 +- .../models/causal_lm.py | 1 + .../models/flash_causal_lm.py | 181 +++++++++++------- .../text_generation_server/models/globals.py | 2 +- .../models/idefics_causal_lm.py | 1 + server/text_generation_server/models/mamba.py | 1 + server/text_generation_server/models/model.py | 16 ++ .../models/seq2seq_lm.py | 1 + .../models/vlm_causal_lm.py | 3 +- server/text_generation_server/server.py | 3 + .../utils/prefill_chunking.py | 24 +++ 22 files changed, 307 insertions(+), 195 deletions(-) create mode 100644 server/text_generation_server/utils/prefill_chunking.py diff --git a/backends/client/src/v3/client.rs b/backends/client/src/v3/client.rs index 479d31bf..61d1ea1b 100644 --- a/backends/client/src/v3/client.rs +++ b/backends/client/src/v3/client.rs @@ -159,6 +159,7 @@ impl Client { blocks: vec![], slots: vec![], prefix_len: 0, + postfix_len: truncate, // Set sampling parameters to also take these ops into account in the max memory parameters: Some(NextTokenChooserParameters { temperature: 0.9, diff --git a/backends/client/src/v3/sharded_client.rs b/backends/client/src/v3/sharded_client.rs index 645c076a..8872f8bd 100644 --- a/backends/client/src/v3/sharded_client.rs +++ b/backends/client/src/v3/sharded_client.rs @@ -246,6 +246,7 @@ impl Health for ShardedClient { blocks: vec![0], slots: (0..16).collect(), prefix_len: 0, + postfix_len: 1, adapter_id: None, }; let batch = Batch { diff --git a/backends/v3/src/backend.rs b/backends/v3/src/backend.rs index 77fdb041..bfe7932f 100644 --- a/backends/v3/src/backend.rs +++ b/backends/v3/src/backend.rs @@ -34,9 +34,13 @@ impl BackendV3 { requires_padding: bool, window_size: Option, speculate: u32, + support_chunking: bool, ) -> Self { - let prefix_caching = - std::env::var("USE_PREFIX_CACHING").unwrap_or("1".to_string()); + if support_chunking { + tracing::warn!("Model supports prefill chunking. 
`waiting_served_ratio` and `max_waiting_tokens` will be ignored."); + } + + let prefix_caching = std::env::var("USE_PREFIX_CACHING").unwrap_or("1".to_string()); let prefix_caching = matches!(prefix_caching.as_str(), "true" | "1"); let attention: String = std::env::var("ATTENTION").unwrap_or("flashinfer".to_string()); @@ -52,6 +56,7 @@ impl BackendV3 { window_size, speculate, max_batch_total_tokens, + support_chunking, ); let batching_task_notifier = Arc::new(Notify::new()); @@ -63,6 +68,7 @@ impl BackendV3 { max_batch_total_tokens, max_waiting_tokens, max_batch_size, + support_chunking, queue.clone(), batching_task_notifier.clone(), )); @@ -127,6 +133,7 @@ pub(crate) async fn batching_task( max_batch_total_tokens: u32, max_waiting_tokens: usize, max_batch_size: Option, + support_chunking: bool, queue: Queue, notifier: Arc, ) { @@ -158,28 +165,44 @@ pub(crate) async fn batching_task( // Get current batch info let batch_size = batch.size; let batch_max_tokens = batch.max_tokens; + let current_tokens = batch.current_tokens; let mut batches = vec![batch]; metrics::gauge!("tgi_batch_current_size").set(batch_size as f64); metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64); - let min_size = if waiting_tokens >= max_waiting_tokens { - // If we didn't onboard any new requests since >= max_waiting_tokens, we try - // to add a new batch even though its size might be small - None - } else { - // Minimum batch size - // TODO: temporarily disable to avoid incorrect deallocation + - // reallocation when using prefix caching. - Some((batch_size as f32 * waiting_served_ratio).floor() as usize) - }; - let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens); - let max_size = - max_batch_size.map(|max_size| max_size.saturating_sub(batch_size as usize)); + + let (min_size, max_size, prefill_token_budget) = if support_chunking { + // Since the next batch will be concatenated with the current batch, + // the current batch tokens must be subtracted to the prefill budget + // In the future, we could concatenate beforehand + let prefill_token_budget = max_batch_prefill_tokens - current_tokens; + // We can ignore min_size and max_size + // Models than rely on max_size cannot support chunking + // Regarding min_size, chunking allow us to consistently run at the compute + // bound, making min_size useless. + (None, None, prefill_token_budget) + } else { + let min_size = if waiting_tokens >= max_waiting_tokens { + // If we didn't onboard any new requests since >= max_waiting_tokens, we try + // to add a new batch even though its size might be small + None + } else { + // Minimum batch size + // TODO: temporarily disable to avoid incorrect deallocation + + // reallocation when using prefix caching. 
+ Some((batch_size as f32 * waiting_served_ratio).floor() as usize) + }; + + let max_size = + max_batch_size.map(|max_size| max_size.saturating_sub(batch_size as usize)); + + (min_size, max_size, max_batch_prefill_tokens) + }; // Try to get a new batch if let Some((mut new_entries, new_batch, span)) = queue - .next_batch(min_size, max_size, max_batch_prefill_tokens, token_budget) + .next_batch(min_size, max_size, prefill_token_budget, token_budget) .await { // Tracking metrics diff --git a/backends/v3/src/client/grpc_client.rs b/backends/v3/src/client/grpc_client.rs index 648662db..3b4432a7 100644 --- a/backends/v3/src/client/grpc_client.rs +++ b/backends/v3/src/client/grpc_client.rs @@ -159,6 +159,7 @@ impl Client { blocks: vec![], slots: vec![], prefix_len: 0, + postfix_len: truncate, // Set sampling parameters to also take these ops into account in the max memory parameters: Some(NextTokenChooserParameters { temperature: 0.9, diff --git a/backends/v3/src/client/mod.rs b/backends/v3/src/client/mod.rs index 755431f4..d4ac50c9 100644 --- a/backends/v3/src/client/mod.rs +++ b/backends/v3/src/client/mod.rs @@ -29,15 +29,6 @@ pub trait Health { async fn model_health(&self) -> Result<()>; } -#[derive(Debug)] -pub struct ShardInfo { - pub requires_padding: bool, - pub dtype: String, - pub device_type: String, - pub window_size: Option, - pub speculate: u32, -} - #[derive(Error, Debug, Clone)] pub enum ClientError { #[error("Could not connect to Text Generation server: {0}")] diff --git a/backends/v3/src/client/sharded_client.rs b/backends/v3/src/client/sharded_client.rs index ea77a696..97a1eab6 100644 --- a/backends/v3/src/client/sharded_client.rs +++ b/backends/v3/src/client/sharded_client.rs @@ -1,6 +1,6 @@ -use crate::client::{ClientError, Result}; +use crate::client::Health; /// Multi shard Client -use crate::client::{Health, ShardInfo}; +use crate::client::{ClientError, Result}; use crate::client::grpc_client::{DecodeTimings, PrefillTimings}; use crate::client::{ @@ -49,13 +49,13 @@ impl ShardedClient { /// Get the model info #[instrument(skip(self))] - pub async fn info(&mut self) -> Result { + pub async fn info(&mut self) -> Result { let futures: Vec<_> = self .clients .iter_mut() .map(|client| client.info()) .collect(); - join_all(futures).await.pop().unwrap().map(ShardInfo::from) + join_all(futures).await.pop().unwrap() } /// GRPC health check @@ -194,18 +194,6 @@ impl ShardedClient { } } -impl From for ShardInfo { - fn from(value: InfoResponse) -> Self { - Self { - requires_padding: value.requires_padding, - dtype: value.dtype, - device_type: value.device_type, - window_size: value.window_size, - speculate: value.speculate, - } - } -} - #[async_trait] impl Health for ShardedClient { async fn device_health(&self) -> Result<()> { @@ -248,6 +236,7 @@ impl Health for ShardedClient { slots: (0..16).collect(), prefix_len: 0, adapter_id: None, + postfix_len: 1, }; let batch = Batch { id: u64::MAX, diff --git a/backends/v3/src/lib.rs b/backends/v3/src/lib.rs index af66b21e..0a7ef223 100644 --- a/backends/v3/src/lib.rs +++ b/backends/v3/src/lib.rs @@ -29,6 +29,8 @@ pub struct BackendInfo { pub max_waiting_tokens: usize, #[schema(nullable = true, example = "null")] pub max_batch_size: Option, + #[schema(example = "false")] + pub support_chunking: bool, } #[allow(clippy::too_many_arguments)] @@ -110,6 +112,7 @@ pub async fn connect_backend( model_device_type: shard_info.device_type.clone(), model_dtype: shard_info.dtype.clone(), speculate: shard_info.speculate as usize, + support_chunking: 
shard_info.support_chunking, }; let backend = BackendV3::new( @@ -122,6 +125,7 @@ pub async fn connect_backend( shard_info.requires_padding, shard_info.window_size, shard_info.speculate, + shard_info.support_chunking, ); tracing::info!("Using backend V3"); diff --git a/backends/v3/src/main.rs b/backends/v3/src/main.rs index 471ddb5a..b4751bd5 100644 --- a/backends/v3/src/main.rs +++ b/backends/v3/src/main.rs @@ -131,25 +131,12 @@ async fn main() -> Result<(), RouterError> { "`max_input_tokens` must be < `max_total_tokens`".to_string(), )); } - if max_input_tokens as u32 > max_batch_prefill_tokens { - return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}"))); - } if validation_workers == 0 { return Err(RouterError::ArgumentValidation( "`validation_workers` must be > 0".to_string(), )); } - - if let Some(ref max_batch_total_tokens) = max_batch_total_tokens { - if max_batch_prefill_tokens > *max_batch_total_tokens { - return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}"))); - } - if max_total_tokens as u32 > *max_batch_total_tokens { - return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}"))); - } - } - if let Some(max_batch_size) = max_batch_size { if max_batch_size == 0 { return Err(RouterError::ArgumentValidation( @@ -158,7 +145,7 @@ async fn main() -> Result<(), RouterError> { } } - let (backend, _backend_info) = connect_backend( + let (backend, backend_info) = connect_backend( max_input_tokens, max_total_tokens, master_shard_uds_path, @@ -170,6 +157,19 @@ async fn main() -> Result<(), RouterError> { ) .await?; + // Validate remaining args now that the backend is known + let support_chunking = backend_info.support_chunking; + let max_batch_total_tokens = backend_info.max_batch_total_tokens; + if max_input_tokens as u32 > max_batch_prefill_tokens && !support_chunking { + return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}"))); + } + if max_batch_prefill_tokens > max_batch_total_tokens { + return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}"))); + } + if max_total_tokens as u32 > max_batch_total_tokens { + return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. 
Given: {max_total_tokens} and {max_batch_total_tokens}"))); + } + // Run server server::run( backend, diff --git a/backends/v3/src/queue.rs b/backends/v3/src/queue.rs index f8123b57..7db0aba3 100644 --- a/backends/v3/src/queue.rs +++ b/backends/v3/src/queue.rs @@ -4,7 +4,7 @@ use crate::client::{ Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters, }; use nohash_hasher::{BuildNoHashHasher, IntMap}; -use std::cmp::{max, min}; +use std::cmp::max; use std::collections::VecDeque; use text_generation_router::infer::InferError; use text_generation_router::infer::InferStreamResponse; @@ -50,6 +50,7 @@ impl Queue { window_size: Option, speculate: u32, max_batch_total_tokens: u32, + support_chunking: bool, ) -> Self { // Create channel let (queue_sender, queue_receiver) = mpsc::unbounded_channel(); @@ -62,6 +63,7 @@ impl Queue { window_size, speculate, max_batch_total_tokens, + support_chunking, queue_receiver, )); @@ -108,6 +110,7 @@ impl Queue { } // Background task responsible of the queue state +#[allow(clippy::too_many_arguments)] async fn queue_task( requires_padding: bool, block_size: u32, @@ -115,6 +118,7 @@ async fn queue_task( window_size: Option, speculate: u32, max_batch_total_tokens: u32, + support_chunking: bool, mut receiver: mpsc::UnboundedReceiver, ) { let mut state = State::new( @@ -124,6 +128,7 @@ async fn queue_task( window_size, speculate, max_batch_total_tokens, + support_chunking, ); while let Some(cmd) = receiver.recv().await { @@ -166,12 +171,14 @@ struct State { /// Paged Attention block size block_size: u32, - /// Sliding window - window_size: Option, - /// Speculation amount speculate: u32, + /// Whether the model allow the prefill chunking + /// If it does, the last request in the batch will be split to exactly match the prefill + /// token budget + support_chunking: bool, + /// Paged Attention Block Allocation block_allocator: Option, } @@ -184,6 +191,7 @@ impl State { window_size: Option, speculate: u32, max_batch_total_tokens: u32, + support_chunking: bool, ) -> Self { let block_allocator = (!requires_padding).then(|| { BlockAllocator::new( @@ -199,8 +207,8 @@ impl State { next_id: 0, next_batch_id: 0, block_size, - window_size, speculate, + support_chunking, block_allocator, } } @@ -268,7 +276,7 @@ impl State { continue; } - let block_allocation = match &self.block_allocator { + let (block_allocation, postfix_len) = match &self.block_allocator { None => { // We pad to max input length in the Python shards // We need to take these padding tokens into the equation @@ -285,34 +293,9 @@ impl State { self.entries.push_front((id, entry)); break 'entry_loop; } - None + (None, entry.request.input_length) } - Some(_block_allocator) => { - prefill_tokens += entry.request.input_length; - let max_new_tokens = match self.window_size { - None => entry.request.stopping_parameters.max_new_tokens, - Some(window_size) => min( - window_size.saturating_sub(entry.request.input_length), - entry.request.stopping_parameters.max_new_tokens, - ), - }; - decode_tokens += max_new_tokens; - - if prefill_tokens > prefill_token_budget - || (prefill_tokens + decode_tokens + self.speculate) > token_budget - { - // Entry is over budget - // Add it back to the front - tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate); - self.entries.push_front((id, entry)); - break; - } - - let tokens = entry.request.input_length - + 
entry.request.stopping_parameters.max_new_tokens - + self.speculate - - 1; - + Some(block_allocator) => { // If users wants the prefill logprobs, we cannot reuse the cache. // So no input_ids for the radix tree. let input_ids = if entry.request.decoder_input_details { @@ -321,10 +304,65 @@ impl State { entry.request.input_ids.clone() }; - Some((tokens, input_ids)) + let tokens = entry.request.input_length + + entry.request.stopping_parameters.max_new_tokens + + self.speculate + - 1; + tracing::debug!("Allocating {tokens} with {input_ids:?}"); + + let block_allocation = match block_allocator.allocate(tokens, input_ids).await { + None => { + // Entry is over budget + // Add it back to the front + tracing::debug!("Over budget: not enough free blocks"); + self.entries.push_front((id, entry)); + break 'entry_loop; + } + Some(mut block_allocation) => { + tracing::debug!("Allocation: {block_allocation:?}"); + max_blocks = max(max_blocks, block_allocation.blocks.len() as u32); + + if block_allocation.prefix_len == entry.request.input_length { + // The whole request was found in the radix trie + // However, for the transformer forward to work, we need to + // have at least one token of postfix. + block_allocation.prefix_len -= 1; + } + + block_allocation + } + }; + + let mut postfix_len = entry.request.input_length - block_allocation.prefix_len; + + // Check equality too as if we don't we might end up with a postfix_len = 0 + // in the next iteration of the loop + if prefill_tokens + postfix_len >= prefill_token_budget { + // Entry is over budget + if self.support_chunking { + // We support chunking, just set postfix_len to exactly match prefill_token_budget + postfix_len = prefill_token_budget - prefill_tokens; + // Push this entry inside the batch + batch.push((id, entry, Some(block_allocation), postfix_len)); + break 'entry_loop; + } else { + // We don't support chunking, this entry needs to go back to the buffer + // Add it back to the front + tracing::debug!( + "Over budget: prefill_tokens={} > {prefill_token_budget}", + prefill_tokens + postfix_len + ); + self.entries.push_front((id, entry)); + break 'entry_loop; + } + } + + prefill_tokens += postfix_len; + + (Some(block_allocation), postfix_len) } }; - batch.push((id, entry, block_allocation)); + batch.push((id, entry, block_allocation, postfix_len)); if Some(batch.len()) == max_size { break; } @@ -342,7 +380,7 @@ impl State { // Batch is too small if batch.len() < min_size { // Add back entries to the queue in the correct order - for (id, entry, _) in batch.into_iter().rev() { + for (id, entry, _, _) in batch.into_iter().rev() { self.entries.push_front((id, entry)); } return None; @@ -353,29 +391,7 @@ impl State { let mut batch_entries = IntMap::with_capacity_and_hasher(self.entries.len(), BuildNoHashHasher::default()); - for (id, mut entry, block_allocation) in batch { - let block_allocation = if let (Some((tokens, input_ids)), Some(block_allocator)) = - (block_allocation, &self.block_allocator) - { - tracing::debug!("Allocating {tokens} with {input_ids:?}"); - match block_allocator.allocate(tokens, input_ids).await { - None => { - // Entry is over budget - // Add it back to the front - tracing::debug!("Over budget: not enough free blocks"); - self.entries.push_front((id, entry)); - continue; - } - Some(block_allocation) => { - tracing::debug!("Allocation: {block_allocation:?}"); - max_blocks = max(max_blocks, block_allocation.blocks.len() as u32); - Some(block_allocation) - } - } - } else { - None - }; - tracing::debug!("Accepting 
entry"); + for (id, mut entry, block_allocation, postfix_len) in batch { // Create a new span to link the batch back to this entry let entry_batch_span = info_span!(parent: &entry.span, "infer"); // Add relationships @@ -429,6 +445,7 @@ impl State { slots, prefix_len, adapter_id: entry.request.adapter_id.clone(), + postfix_len, }); // Set batch_time entry.batch_time = Some(Instant::now()); @@ -436,12 +453,6 @@ impl State { batch_entries.insert(id, entry); } - // Empty batch - if batch_requests.is_empty() { - tracing::debug!("Filterered out all entries"); - return None; - } - // Final batch size let size = batch_requests.len() as u32; next_batch_span.record("batch_size", size); diff --git a/benchmark/src/generation.rs b/benchmark/src/generation.rs index 789c7b51..fff221ef 100644 --- a/benchmark/src/generation.rs +++ b/benchmark/src/generation.rs @@ -159,6 +159,7 @@ async fn prefill( blocks: vec![], slots: vec![], prefix_len: 0, + postfix_len: sequence_length, adapter_id: None, }) .collect(); diff --git a/proto/v3/generate.proto b/proto/v3/generate.proto index 34894bda..cfb92ba8 100644 --- a/proto/v3/generate.proto +++ b/proto/v3/generate.proto @@ -34,6 +34,7 @@ message InfoResponse { string device_type = 3; optional uint32 window_size = 4; uint32 speculate = 5; + bool support_chunking = 6; } /// Empty request @@ -139,6 +140,8 @@ message Request { uint32 prefix_len = 12; /// Context truncation bool add_special_tokens = 13; + /// Postfix length for prefill chunking + uint32 postfix_len = 14; } message Batch { @@ -163,6 +166,8 @@ message CachedBatch { uint32 size = 3; /// Maximum number of tokens this batch will grow to uint32 max_tokens = 4; + /// Number of tokens in the next forward + uint32 current_tokens = 5; } enum FinishReason { diff --git a/server/tests/conftest.py b/server/tests/conftest.py index 1efeba58..b1a30e02 100644 --- a/server/tests/conftest.py +++ b/server/tests/conftest.py @@ -1,7 +1,7 @@ import pytest -import os from text_generation_server.pb import generate_pb2 + @pytest.fixture def default_pb_parameters(): return generate_pb2.NextTokenChooserParameters( diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index 28534d0f..1378f590 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -76,6 +76,7 @@ class CausalLMBatch(Batch): request_ids=[r.id for r in self.requests], size=len(self), max_tokens=self.max_tokens, + current_tokens=len(self), ) @classmethod diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 8ee9d184..b39fe0ff 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -16,7 +16,17 @@ from transformers import ( AutoTokenizer, GenerationConfig, ) -from typing import Any, ContextManager, Iterable, Optional, Tuple, List, Type, Dict, Union +from typing import ( + Any, + ContextManager, + Iterable, + Optional, + Tuple, + List, + Type, + Dict, + Union, +) from text_generation_server.adapters import AdapterBatchData, AdapterBatchMetadata from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE @@ -24,6 +34,10 @@ from text_generation_server.utils.chunks import concat_text_chunks from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.models import Model from text_generation_server.utils.log import log_master +from 
text_generation_server.utils.prefill_chunking import ( + get_support_chunking, + get_max_prefill_tokens, +) from text_generation_server.utils.tokens import batch_top_tokens from text_generation_server.utils.speculate import get_speculate from text_generation_server.utils import ( @@ -60,12 +74,9 @@ from text_generation_server.utils.import_utils import ( tracer = trace.get_tracer(__name__) - # Will be set in init SLIDING_WINDOW: Optional[int] = None -TOKEN_BUDGET = 8 - def set_sliding_window(sliding_window: int): global SLIDING_WINDOW @@ -206,6 +217,11 @@ class FlashCausalLMBatch(Batch): request_ids=[r.id for r in self.requests], size=len(self), max_tokens=self.num_blocks * BLOCK_SIZE, + current_tokens=( + sum([len(i) for i in self.input_ids]) + if isinstance(self.input_ids, list) + else len(self.input_ids) + ), ) @classmethod @@ -272,7 +288,7 @@ class FlashCausalLMBatch(Batch): prompt_lengths.append(prompt_length) prefix_length = r.prefix_len - postfix_length = prefix_length + 10 + postfix_length = r.postfix_len assert ( prefix_length <= prompt_length ), f"Prefix {prefix_length} vs input {prompt_length}" @@ -282,10 +298,13 @@ class FlashCausalLMBatch(Batch): if prefix_length + postfix_length < prompt_length: # FIXME: speculate is not supported for context chunking at the moment assert speculate == 0 + assert get_support_chunking() + assert postfix_length > 0 prefix_ids.append(tokenized_input[:prefix_length]) - postfix_ids = tokenized_input[prefix_length : postfix_length] - # postfix_ids = tokenized_input[prefix_length:] + postfix_ids = tokenized_input[ + prefix_length : prefix_length + postfix_length + ] postfix_length = len(postfix_ids) postfix_lengths.append(postfix_length) @@ -371,7 +390,6 @@ class FlashCausalLMBatch(Batch): requests=pb.requests, requests_idx_mapping=requests_idx_mapping, input_ids=all_postfix_ids, - block_tables=block_tables, block_tables_tensor=block_tables_tensor, prefix_lengths=prefix_lengths, @@ -395,7 +413,6 @@ class FlashCausalLMBatch(Batch): max_blocks=max_blocks, speculative_ids=None, prompt_lengths_tensor=prompt_lengths_tensor, - # These values will be set by `FlashCausalLMBatch.prepare_for_prefill` position_ids=None, cu_seqlen_prefill=None, @@ -431,7 +448,7 @@ class FlashCausalLMBatch(Batch): if len(request_ids) == len(self): return self - device = self.input_ids.device + device = self.block_tables_tensor.device # New values after filtering requests_idx_mapping = {} @@ -552,13 +569,13 @@ class FlashCausalLMBatch(Batch): if self.prefilling: # These values will be set by `FlashCausalLMBatch.prepare_for_prefill` - position_ids=None - start_slots=None - slot_indices=None - slots=None - prefix_lengths_tensor=None - postfix_lengths_tensor=None - adapter_meta=None + position_ids = None + start_slots = None + slot_indices = None + slots = None + prefix_lengths_tensor = None + postfix_lengths_tensor = None + adapter_meta = None else: # Index into tensors input_ids = self.input_ids[indices] @@ -643,24 +660,24 @@ class FlashCausalLMBatch(Batch): max_current_length = 0 for b in batches: total_batch_size += len(b) - total_slots += len(b.slots) + max_blocks = max(max_blocks, b.max_blocks) + # If `b` is prefilling and was just filtered, `b.slots` is None + # `total_slots` is not used if any of the batches is prefilling + total_slots += len(b.slots) if not b.prefilling else 0 num_blocks += b.num_blocks speculative_length = ( b.speculative_ids.shape[1] if b.speculative_ids is not None else 0 ) - max_blocks = max(max_blocks, b.max_blocks) max_postfix_length = 
max(max_postfix_length, b.max_postfix_length) max_current_length = max(max_current_length, b.max_current_length) max_length = max( max_length, max( - prefix_length - + postfix_length + prompt_length + stopping_criteria.max_new_tokens + speculative_length - - stopping_criteria.current_tokens - for prefix_length, postfix_length, stopping_criteria in zip( - b.prefix_lengths, b.postfix_lengths, b.stopping_criterias + for prompt_length, stopping_criteria in zip( + b.prompt_lengths, b.stopping_criterias ) ), ) @@ -669,14 +686,14 @@ class FlashCausalLMBatch(Batch): if prefilling: input_ids = [] # These values will be set by `FlashCausalLMBatch.prepare_for_prefill` - position_ids=None - start_slots=None - slots=None - slot_indices=None - prefix_lengths_tensor=None - postfix_lengths_tensor=None - adapter_meta=None - adapter_segment_builder=None + position_ids = None + start_slots = None + slots = None + slot_indices = None + prefix_lengths_tensor = None + postfix_lengths_tensor = None + adapter_meta = None + adapter_segment_builder = None else: input_ids = batches[0].input_ids.new_empty(total_batch_size) position_ids = batches[0].position_ids.new_empty(total_batch_size) @@ -746,8 +763,6 @@ class FlashCausalLMBatch(Batch): start_index = cumulative_batch_size end_index = cumulative_batch_size + len(batch) - slots_start_index = cumulative_slots - slots_end_index = cumulative_slots + len(batch.slots) # Copy tensors (GPU) top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor @@ -761,10 +776,17 @@ class FlashCausalLMBatch(Batch): prompt_lengths_tensor[start_index:end_index] = batch.prompt_lengths_tensor if not prefilling: + slots_start_index = cumulative_slots + slots_end_index = cumulative_slots + len(batch.slots) + input_ids[start_index:end_index] = batch.input_ids position_ids[start_index:end_index] = batch.position_ids - slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots - postfix_lengths_tensor[start_index:end_index] = batch.postfix_lengths_tensor + slot_indices[start_index:end_index] = ( + batch.slot_indices + cumulative_slots + ) + postfix_lengths_tensor[start_index:end_index] = ( + batch.postfix_lengths_tensor + ) slots[slots_start_index:slots_end_index] = batch.slots # Copy over adapter indices @@ -779,11 +801,17 @@ class FlashCausalLMBatch(Batch): cumulative_adapter_indices_size = adapter_end_index adapter_set.update(batch.adapter_meta.adapter_set) adapter_segment_builder.concat( - batch.adapter_meta.adapter_segments, batch.adapter_meta.segment_indices + batch.adapter_meta.adapter_segments, + batch.adapter_meta.segment_indices, + ) + prefix_lengths_tensor[start_index:end_index] = ( + batch.prefix_lengths_tensor ) - prefix_lengths_tensor[start_index:end_index] = batch.prefix_lengths_tensor start_slots.append(batch.start_slots + cumulative_slots) + + # Update + cumulative_slots += len(batch.slots) else: if isinstance(batch.input_ids, torch.Tensor): batch.input_ids = batch.input_ids.view(-1, 1).tolist() @@ -810,7 +838,6 @@ class FlashCausalLMBatch(Batch): # Update cumulative_batch_size += len(batch) - cumulative_slots += len(batch.slots) if start_slots is not None: start_slots = torch.concat(start_slots) @@ -915,7 +942,7 @@ class FlashCausalLMBatch(Batch): postfix_length, prompt_length, request_prefilling, - blocks + blocks, ) in enumerate( zip( self.requests, @@ -923,7 +950,7 @@ class FlashCausalLMBatch(Batch): self.postfix_lengths, self.prompt_lengths, self.prefilling_mask, - self.block_tables + self.block_tables, ) ): next_chunk_length = postfix_length 
@@ -967,9 +994,7 @@ class FlashCausalLMBatch(Batch): no_prefill_logprobs = no_prefill_logprobs and not prefill_logprobs if prefill_logprobs: - prefill_head_indices.append( - request_position_ids + cumulative_length - ) + prefill_head_indices.append(request_position_ids + cumulative_length) prefill_next_token_indices.append( prefill_out_cumulative_length + postfix_length - 1 ) @@ -988,7 +1013,6 @@ class FlashCausalLMBatch(Batch): prefill_cu_outlens.append(prefill_out_cumulative_length + 1) prefill_out_cumulative_length += 1 - start_slots.append(cumulative_slot_tokens) slots.extend(request_slots) slot_indices.append(request_slot_indices) @@ -998,9 +1022,7 @@ class FlashCausalLMBatch(Batch): ADAPTER_TO_INDEX = get_adapter_to_index() adapter_index = ADAPTER_TO_INDEX.get(r.adapter_id, 0) - adapter_indices_list.append( - torch.full((next_chunk_length,), adapter_index) - ) + adapter_indices_list.append(torch.full((next_chunk_length,), adapter_index)) adapter_set.add(adapter_index) # Update @@ -1240,6 +1262,7 @@ class FlashCausalLM(Model): rank=rank, world_size=world_size, sliding_window=config.sliding_window, + support_chunking=True, ) @property @@ -1764,29 +1787,43 @@ class FlashCausalLM(Model): finished_prefilling = True next_chunk_lengths = [] if prefill: - next_prefilling_mask = [] - # Budget in tokens for the next batch - # We remove next input ids to always have enough space for at least a single decode - # for the remaining requests - batch_budget = TOKEN_BUDGET - len(batch) - for prefix_length, postfix_length, prompt_length in zip( - batch.prefix_lengths, batch.postfix_lengths, batch.prompt_lengths - ): - remaining_prefill_tokens = max( - prompt_length - prefix_length - postfix_length, 0 - ) - if remaining_prefill_tokens > 0: - next_chunk_length = max( - min(remaining_prefill_tokens, batch_budget), 1 + if get_support_chunking(): + next_prefilling_mask = [] + # Budget in tokens for the next batch + # We remove len(batch) to always have enough space for at least a single decode + # for the remaining requests + batch_budget = get_max_prefill_tokens() - len(batch) + # We reverse to prioritize older requests + # zip() is not reversible so reverse the underlying lists instead + for prefix_length, postfix_length, prompt_length in zip( + reversed(batch.prefix_lengths), + reversed(batch.postfix_lengths), + reversed(batch.prompt_lengths), + ): + remaining_prefill_tokens = max( + prompt_length - prefix_length - postfix_length, 0 ) - batch_budget -= next_chunk_length - finished_prefilling = False - next_prefilling_mask.append(True) - else: - # Since speculation will be turned off, this is always true - next_chunk_length = 1 - next_prefilling_mask.append(False) - next_chunk_lengths.append(next_chunk_length) + if remaining_prefill_tokens > 0: + next_chunk_length = max( + min(remaining_prefill_tokens, batch_budget), 1 + ) + batch_budget -= next_chunk_length + finished_prefilling = False + next_prefilling_mask.append(True) + else: + # Since speculation will be turned off, this is always true + next_chunk_length = 1 + next_prefilling_mask.append(False) + next_chunk_lengths.append(next_chunk_length) + + # Reverse back the obtained values² + next_chunk_lengths.reverse() + next_prefilling_mask.reverse() + else: + # The model does not support chunking + # We know we only do a single prefill + finished_prefilling = True + next_prefilling_mask = [False] * len(batch) batch.prefilling = not finished_prefilling batch.prefilling_mask = next_prefilling_mask @@ -2179,7 +2216,9 @@ class FlashCausalLM(Model): # have 
more than one new token per request with speculative decoding for next_token_id in _next_token_ids: batch.next_token_chooser = ( - batch.next_token_chooser.advance_grammar_single(i, next_token_id) + batch.next_token_chooser.advance_grammar_single( + i, next_token_id + ) ) # Update values diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py index 1830dc42..6bf8d3ff 100644 --- a/server/text_generation_server/models/globals.py +++ b/server/text_generation_server/models/globals.py @@ -18,7 +18,7 @@ if PREFIX_CACHING and ATTENTION not in {"flashinfer", "flashdecoding"}: raise RuntimeError("Prefix caching is only supported with flashinfer") MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None -TGI_WIGGLE_ROOM = float(os.getenv("TGI_WIGGLE_ROOM", "0.95")) +TGI_WIGGLE_ROOM = float(os.getenv("TGI_WIGGLE_ROOM", "0.90")) assert TGI_WIGGLE_ROOM > 0 assert TGI_WIGGLE_ROOM < 1 diff --git a/server/text_generation_server/models/idefics_causal_lm.py b/server/text_generation_server/models/idefics_causal_lm.py index 9a7a6fe1..34b74ba8 100644 --- a/server/text_generation_server/models/idefics_causal_lm.py +++ b/server/text_generation_server/models/idefics_causal_lm.py @@ -83,6 +83,7 @@ class IdeficsCausalLMBatch(Batch): request_ids=[r.id for r in self.requests], size=len(self), max_tokens=self.max_tokens, + current_tokens=len(self), ) @classmethod diff --git a/server/text_generation_server/models/mamba.py b/server/text_generation_server/models/mamba.py index f6dcde68..dfc61fb8 100644 --- a/server/text_generation_server/models/mamba.py +++ b/server/text_generation_server/models/mamba.py @@ -116,6 +116,7 @@ class MambaBatch(Batch): request_ids=[r.id for r in self.requests], size=len(self), max_tokens=self.max_tokens, + current_tokens=len(self), ) @classmethod diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py index 20402e07..02f3dbf9 100644 --- a/server/text_generation_server/models/model.py +++ b/server/text_generation_server/models/model.py @@ -5,8 +5,11 @@ from abc import ABC, abstractmethod from typing import List, Tuple, Optional, TypeVar, Type, Dict from collections import defaultdict from transformers import PreTrainedTokenizerBase +from loguru import logger from text_generation_server.models.types import Batch, Generation +from text_generation_server.utils.log import log_master +from text_generation_server.utils.prefill_chunking import set_support_chunking from text_generation_server.utils.speculate import get_speculate from text_generation_server.pb.generate_pb2 import InfoResponse from text_generation_server.adapters.weights import LayerAdapterWeights @@ -31,6 +34,7 @@ class Model(ABC): sliding_window: Optional[int] = None, speculate: Optional[int] = None, adapter_id: str = BASE_MODEL_ADAPTER_ID, + support_chunking: bool = False, ): self.model_id = model_id self.model = model.eval() @@ -60,6 +64,17 @@ class Model(ABC): speculate = get_speculate() self.speculate = speculate + if speculate != 0 and support_chunking: + log_master( + logger.warning, + "Prefill chunking does not support speculation yet. 
" + "Prefill chunking will be turned off", + ) + support_chunking = False + + self.support_chunking = support_chunking + set_support_chunking(support_chunking) + self.has_position_ids = ( inspect.signature(model.forward).parameters.get("position_ids", None) is not None @@ -78,6 +93,7 @@ class Model(ABC): device_type=self.device.type, window_size=self.sliding_window, speculate=self.speculate, + support_chunking=self.support_chunking, ) @property diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py index 04d4c28b..e2d7aa4d 100644 --- a/server/text_generation_server/models/seq2seq_lm.py +++ b/server/text_generation_server/models/seq2seq_lm.py @@ -80,6 +80,7 @@ class Seq2SeqLMBatch(Batch): request_ids=[r.id for r in self.requests], size=len(self), max_tokens=self.max_tokens, + current_tokens=len(self), ) @classmethod diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index 937811d7..1a578d7b 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -357,7 +357,6 @@ class VlmCausalLM(FlashCausalLM): else: cuda_graph = None if cu_seqlen_prefill is not None or cuda_graph is None: - input_lengths = postfix_lengths + prefix_lengths_tensor if PREFIX_CACHING: block_tables = block_tables_to_ragged( block_tables=block_tables, @@ -424,7 +423,7 @@ class VlmCausalLM(FlashCausalLM): cuda_graph["postfix_lengths"][: postfix_lengths.shape[0]] = postfix_lengths cuda_graph["prefix_lengths"].zero_() cuda_graph["prefix_lengths"][ - : prefix_lengths_tensor.shape[0] + : prefix_lengths_tensor.shape[0] ] = prefix_lengths_tensor with self._forward_context( diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py index 46e342a4..bd4b3a53 100644 --- a/server/text_generation_server/server.py +++ b/server/text_generation_server/server.py @@ -15,6 +15,7 @@ from text_generation_server.cache import Cache from text_generation_server.interceptor import ExceptionInterceptor from text_generation_server.models import Model, get_model_with_lora_adapters from text_generation_server.utils.adapter import AdapterInfo +from text_generation_server.utils.prefill_chunking import set_max_prefill_tokens try: from text_generation_server.models.pali_gemma import PaliGemmaBatch @@ -96,6 +97,8 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb()) async def Warmup(self, request, context): + set_max_prefill_tokens(request.max_prefill_tokens) + if self.quantize in {"exl2", "gptq"}: try: # When using GPTQ, Exllama kernels need some global kernels diff --git a/server/text_generation_server/utils/prefill_chunking.py b/server/text_generation_server/utils/prefill_chunking.py new file mode 100644 index 00000000..c227d30f --- /dev/null +++ b/server/text_generation_server/utils/prefill_chunking.py @@ -0,0 +1,24 @@ +from typing import Optional + +SUPPORT_CHUNKING: Optional[bool] = None +MAX_PREFILL_TOKENS: Optional[int] = None + + +def set_support_chunking(support_chunking: bool): + global SUPPORT_CHUNKING + SUPPORT_CHUNKING = support_chunking + + +def get_support_chunking() -> bool: + global SUPPORT_CHUNKING + return SUPPORT_CHUNKING + + +def set_max_prefill_tokens(max_prefill_tokens: int): + global MAX_PREFILL_TOKENS + MAX_PREFILL_TOKENS = max_prefill_tokens + + +def get_max_prefill_tokens() -> int: + global 
MAX_PREFILL_TOKENS
+    return MAX_PREFILL_TOKENS
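
Note (not part of the patch): the snippet below is a minimal, standalone Python rendition of the router-side budgeting that backends/v3/src/queue.rs applies once support_chunking is enabled: the last request that would overflow prefill_token_budget is kept, but its postfix_len is clamped so the batch lands exactly on the budget. The function name plan_prefill, its tuple layout, and the example numbers are illustrative stand-ins rather than code from this PR; block allocation and re-queuing of rejected entries are omitted.

def plan_prefill(entries, prefill_token_budget, support_chunking):
    """Pick (request_id, postfix_len) pairs that fit the prefill budget.

    `entries` holds (request_id, input_length, prefix_len) tuples, where
    `prefix_len` counts tokens already recovered from the prefix cache.
    """
    batch = []
    prefill_tokens = 0
    for request_id, input_length, prefix_len in entries:
        # Tokens that still need a prefill forward pass for this request
        postfix_len = input_length - prefix_len
        if prefill_tokens + postfix_len >= prefill_token_budget:
            if support_chunking:
                # Chunk the last request so the batch exactly matches the
                # budget; the rest of its prompt is prefilled in later passes.
                postfix_len = prefill_token_budget - prefill_tokens
                batch.append((request_id, postfix_len))
            # Without chunking, the real queue pushes the entry back instead.
            break
        prefill_tokens += postfix_len
        batch.append((request_id, postfix_len))
    return batch


if __name__ == "__main__":
    # 300 + 400 + 500 uncached tokens against a 1000-token budget:
    # the third request is chunked down to the 300 tokens that remain.
    assert plan_prefill([(0, 300, 0), (1, 400, 0), (2, 500, 0)], 1000, True) == [
        (0, 300),
        (1, 400),
        (2, 300),
    ]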
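
A second hedged sketch, this time of the server side: after each forward pass, the new code in flash_causal_lm.py sizes the next prefill chunk for every request that is still prefilling against the budget recorded by set_max_prefill_tokens at warmup, reserving one token per request so requests that are already decoding always fit. The helper name next_chunk_lengths and the demo values are assumptions for illustration only; the real loop additionally rebuilds the batch tensors around these lengths.

from typing import List, Tuple


def next_chunk_lengths(
    prefix_lengths: List[int],
    postfix_lengths: List[int],
    prompt_lengths: List[int],
    max_prefill_tokens: int,
) -> Tuple[List[int], List[bool]]:
    # Keep room for at least one token per request in the batch
    batch_budget = max_prefill_tokens - len(prompt_lengths)
    chunk_lengths: List[int] = []
    prefilling_mask: List[bool] = []
    # Walk the batch in reverse, as the patch does, so one end of the batch is
    # served from the budget first; zip() is not reversible, so reverse the lists.
    for prefix, postfix, prompt in zip(
        reversed(prefix_lengths), reversed(postfix_lengths), reversed(prompt_lengths)
    ):
        remaining = max(prompt - prefix - postfix, 0)
        if remaining > 0:
            # Still prefilling: take as much of the budget as possible, never 0
            chunk = max(min(remaining, batch_budget), 1)
            batch_budget -= chunk
            prefilling_mask.append(True)
        else:
            # Prefill finished: the request only needs a single decode token
            chunk = 1
            prefilling_mask.append(False)
        chunk_lengths.append(chunk)
    # Restore the original request order
    chunk_lengths.reverse()
    prefilling_mask.reverse()
    return chunk_lengths, prefilling_mask


if __name__ == "__main__":
    # One request with 32 prompt tokens left to prefill and one already decoding,
    # under a 20-token budget (2 reserved): prints ([18, 1], [True, False])
    print(next_chunk_lengths([0, 0], [16, 8], [48, 8], 20))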