mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00

load tested

This commit is contained in:
parent 34f5dc525e
commit 7f9abde3f8
@@ -159,6 +159,7 @@ impl Client {
                blocks: vec![],
                slots: vec![],
                prefix_len: 0,
                postfix_len: truncate,
                // Set sampling parameters to also take these ops into account in the max memory
                parameters: Some(NextTokenChooserParameters {
                    temperature: 0.9,
@@ -246,6 +246,7 @@ impl Health for ShardedClient {
            blocks: vec![0],
            slots: (0..16).collect(),
            prefix_len: 0,
            postfix_len: 1,
            adapter_id: None,
        };
        let batch = Batch {
@@ -34,9 +34,13 @@ impl BackendV3 {
        requires_padding: bool,
        window_size: Option<u32>,
        speculate: u32,
        support_chunking: bool,
    ) -> Self {
        let prefix_caching =
            std::env::var("USE_PREFIX_CACHING").unwrap_or("1".to_string());
        if support_chunking {
            tracing::warn!("Model supports prefill chunking. `waiting_served_ratio` and `max_waiting_tokens` will be ignored.");
        }

        let prefix_caching = std::env::var("USE_PREFIX_CACHING").unwrap_or("1".to_string());
        let prefix_caching = matches!(prefix_caching.as_str(), "true" | "1");
        let attention: String = std::env::var("ATTENTION").unwrap_or("flashinfer".to_string());
@@ -52,6 +56,7 @@ impl BackendV3 {
            window_size,
            speculate,
            max_batch_total_tokens,
            support_chunking,
        );
        let batching_task_notifier = Arc::new(Notify::new());
@@ -63,6 +68,7 @@ impl BackendV3 {
            max_batch_total_tokens,
            max_waiting_tokens,
            max_batch_size,
            support_chunking,
            queue.clone(),
            batching_task_notifier.clone(),
        ));
@@ -127,6 +133,7 @@ pub(crate) async fn batching_task(
    max_batch_total_tokens: u32,
    max_waiting_tokens: usize,
    max_batch_size: Option<usize>,
    support_chunking: bool,
    queue: Queue,
    notifier: Arc<Notify>,
) {
@@ -158,10 +165,24 @@ pub(crate) async fn batching_task(
            // Get current batch info
            let batch_size = batch.size;
            let batch_max_tokens = batch.max_tokens;
            let current_tokens = batch.current_tokens;
            let mut batches = vec![batch];
            metrics::gauge!("tgi_batch_current_size").set(batch_size as f64);
            metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64);

            let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);

            let (min_size, max_size, prefill_token_budget) = if support_chunking {
                // Since the next batch will be concatenated with the current batch,
                // the current batch tokens must be subtracted from the prefill budget
                // In the future, we could concatenate beforehand
                let prefill_token_budget = max_batch_prefill_tokens - current_tokens;
                // We can ignore min_size and max_size
                // Models that rely on max_size cannot support chunking
                // Regarding min_size, chunking allows us to consistently run at the compute
                // bound, making min_size useless.
                (None, None, prefill_token_budget)
            } else {
                let min_size = if waiting_tokens >= max_waiting_tokens {
                    // If we didn't onboard any new requests since >= max_waiting_tokens, we try
                    // to add a new batch even though its size might be small
@@ -173,13 +194,15 @@ pub(crate) async fn batching_task(
                    Some((batch_size as f32 * waiting_served_ratio).floor() as usize)
                };

                let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
                let max_size =
                    max_batch_size.map(|max_size| max_size.saturating_sub(batch_size as usize));

                (min_size, max_size, max_batch_prefill_tokens)
            };

            // Try to get a new batch
            if let Some((mut new_entries, new_batch, span)) = queue
                .next_batch(min_size, max_size, max_batch_prefill_tokens, token_budget)
                .next_batch(min_size, max_size, prefill_token_budget, token_budget)
                .await
            {
                // Tracking metrics
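Note: to make the budget selection above easier to follow, here is a minimal Python sketch of the same decision (illustrative, hypothetical names; not the router's actual code). With chunking, the prefill budget is whatever part of max_batch_prefill_tokens the running batch has not already consumed, and min_size/max_size are ignored; without chunking, the full prefill budget is used together with min_size/max_size.

def next_batch_limits(
    support_chunking: bool,
    max_batch_prefill_tokens: int,
    max_batch_total_tokens: int,
    batch_size: int,
    batch_max_tokens: int,
    current_tokens: int,
    waiting_tokens: int,
    max_waiting_tokens: int,
    waiting_served_ratio: float,
    max_batch_size: int | None,
):
    """Illustrative sketch of the budget selection done in batching_task."""
    # Tokens still available for the whole batch (prefill + decode).
    token_budget = max(max_batch_total_tokens - batch_max_tokens, 0)

    if support_chunking:
        # New entries are concatenated with the running batch, so the tokens
        # already scheduled for this forward pass come out of the prefill budget.
        prefill_token_budget = max_batch_prefill_tokens - current_tokens
        return None, None, prefill_token_budget, token_budget

    # Without chunking: only accept a small follow-up batch once we have waited long enough.
    min_size = (
        None
        if waiting_tokens >= max_waiting_tokens
        else int(batch_size * waiting_served_ratio)
    )
    max_size = None if max_batch_size is None else max(max_batch_size - batch_size, 0)
    return min_size, max_size, max_batch_prefill_tokens, token_budget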
@@ -159,6 +159,7 @@ impl Client {
                blocks: vec![],
                slots: vec![],
                prefix_len: 0,
                postfix_len: truncate,
                // Set sampling parameters to also take these ops into account in the max memory
                parameters: Some(NextTokenChooserParameters {
                    temperature: 0.9,
@@ -29,15 +29,6 @@ pub trait Health {
    async fn model_health(&self) -> Result<()>;
}

#[derive(Debug)]
pub struct ShardInfo {
    pub requires_padding: bool,
    pub dtype: String,
    pub device_type: String,
    pub window_size: Option<u32>,
    pub speculate: u32,
}

#[derive(Error, Debug, Clone)]
pub enum ClientError {
    #[error("Could not connect to Text Generation server: {0}")]
@@ -1,6 +1,6 @@
use crate::client::{ClientError, Result};
use crate::client::Health;
/// Multi shard Client
use crate::client::{Health, ShardInfo};
use crate::client::{ClientError, Result};

use crate::client::grpc_client::{DecodeTimings, PrefillTimings};
use crate::client::{
@@ -49,13 +49,13 @@ impl ShardedClient {

    /// Get the model info
    #[instrument(skip(self))]
    pub async fn info(&mut self) -> Result<ShardInfo> {
    pub async fn info(&mut self) -> Result<InfoResponse> {
        let futures: Vec<_> = self
            .clients
            .iter_mut()
            .map(|client| client.info())
            .collect();
        join_all(futures).await.pop().unwrap().map(ShardInfo::from)
        join_all(futures).await.pop().unwrap()
    }

    /// GRPC health check
@@ -194,18 +194,6 @@ impl ShardedClient {
    }
}

impl From<InfoResponse> for ShardInfo {
    fn from(value: InfoResponse) -> Self {
        Self {
            requires_padding: value.requires_padding,
            dtype: value.dtype,
            device_type: value.device_type,
            window_size: value.window_size,
            speculate: value.speculate,
        }
    }
}

#[async_trait]
impl Health for ShardedClient {
    async fn device_health(&self) -> Result<()> {
@@ -248,6 +236,7 @@ impl Health for ShardedClient {
            slots: (0..16).collect(),
            prefix_len: 0,
            adapter_id: None,
            postfix_len: 1,
        };
        let batch = Batch {
            id: u64::MAX,
@@ -29,6 +29,8 @@ pub struct BackendInfo {
    pub max_waiting_tokens: usize,
    #[schema(nullable = true, example = "null")]
    pub max_batch_size: Option<usize>,
    #[schema(example = "false")]
    pub support_chunking: bool,
}

#[allow(clippy::too_many_arguments)]
@@ -110,6 +112,7 @@ pub async fn connect_backend(
        model_device_type: shard_info.device_type.clone(),
        model_dtype: shard_info.dtype.clone(),
        speculate: shard_info.speculate as usize,
        support_chunking: shard_info.support_chunking,
    };

    let backend = BackendV3::new(
@@ -122,6 +125,7 @@ pub async fn connect_backend(
        shard_info.requires_padding,
        shard_info.window_size,
        shard_info.speculate,
        shard_info.support_chunking,
    );

    tracing::info!("Using backend V3");
@@ -131,25 +131,12 @@ async fn main() -> Result<(), RouterError> {
            "`max_input_tokens` must be < `max_total_tokens`".to_string(),
        ));
    }
    if max_input_tokens as u32 > max_batch_prefill_tokens {
        return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}")));
    }

    if validation_workers == 0 {
        return Err(RouterError::ArgumentValidation(
            "`validation_workers` must be > 0".to_string(),
        ));
    }

    if let Some(ref max_batch_total_tokens) = max_batch_total_tokens {
        if max_batch_prefill_tokens > *max_batch_total_tokens {
            return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")));
        }
        if max_total_tokens as u32 > *max_batch_total_tokens {
            return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")));
        }
    }

    if let Some(max_batch_size) = max_batch_size {
        if max_batch_size == 0 {
            return Err(RouterError::ArgumentValidation(
@@ -158,7 +145,7 @@ async fn main() -> Result<(), RouterError> {
        }
    }

    let (backend, _backend_info) = connect_backend(
    let (backend, backend_info) = connect_backend(
        max_input_tokens,
        max_total_tokens,
        master_shard_uds_path,
@@ -170,6 +157,19 @@ async fn main() -> Result<(), RouterError> {
    )
    .await?;

    // Validate remaining args now that the backend is known
    let support_chunking = backend_info.support_chunking;
    let max_batch_total_tokens = backend_info.max_batch_total_tokens;
    if max_input_tokens as u32 > max_batch_prefill_tokens && !support_chunking {
        return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}")));
    }
    if max_batch_prefill_tokens > max_batch_total_tokens {
        return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")));
    }
    if max_total_tokens as u32 > max_batch_total_tokens {
        return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")));
    }

    // Run server
    server::run(
        backend,
@@ -4,7 +4,7 @@ use crate::client::{
    Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
};
use nohash_hasher::{BuildNoHashHasher, IntMap};
use std::cmp::{max, min};
use std::cmp::max;
use std::collections::VecDeque;
use text_generation_router::infer::InferError;
use text_generation_router::infer::InferStreamResponse;
@@ -50,6 +50,7 @@ impl Queue {
        window_size: Option<u32>,
        speculate: u32,
        max_batch_total_tokens: u32,
        support_chunking: bool,
    ) -> Self {
        // Create channel
        let (queue_sender, queue_receiver) = mpsc::unbounded_channel();
@@ -62,6 +63,7 @@ impl Queue {
            window_size,
            speculate,
            max_batch_total_tokens,
            support_chunking,
            queue_receiver,
        ));
@@ -108,6 +110,7 @@ impl Queue {
}

// Background task responsible for the queue state
#[allow(clippy::too_many_arguments)]
async fn queue_task(
    requires_padding: bool,
    block_size: u32,
@@ -115,6 +118,7 @@ async fn queue_task(
    window_size: Option<u32>,
    speculate: u32,
    max_batch_total_tokens: u32,
    support_chunking: bool,
    mut receiver: mpsc::UnboundedReceiver<QueueCommand>,
) {
    let mut state = State::new(
@@ -124,6 +128,7 @@ async fn queue_task(
        window_size,
        speculate,
        max_batch_total_tokens,
        support_chunking,
    );

    while let Some(cmd) = receiver.recv().await {
@@ -166,12 +171,14 @@ struct State {
    /// Paged Attention block size
    block_size: u32,

    /// Sliding window
    window_size: Option<u32>,

    /// Speculation amount
    speculate: u32,

    /// Whether the model allows prefill chunking
    /// If it does, the last request in the batch will be split to exactly match the prefill
    /// token budget
    support_chunking: bool,

    /// Paged Attention Block Allocation
    block_allocator: Option<BlockAllocator>,
}
@@ -184,6 +191,7 @@ impl State {
        window_size: Option<u32>,
        speculate: u32,
        max_batch_total_tokens: u32,
        support_chunking: bool,
    ) -> Self {
        let block_allocator = (!requires_padding).then(|| {
            BlockAllocator::new(
@@ -199,8 +207,8 @@ impl State {
            next_id: 0,
            next_batch_id: 0,
            block_size,
            window_size,
            speculate,
            support_chunking,
            block_allocator,
        }
    }
@@ -268,7 +276,7 @@ impl State {
                continue;
            }

            let block_allocation = match &self.block_allocator {
            let (block_allocation, postfix_len) = match &self.block_allocator {
                None => {
                    // We pad to max input length in the Python shards
                    // We need to take these padding tokens into the equation
@@ -285,34 +293,9 @@ impl State {
                        self.entries.push_front((id, entry));
                        break 'entry_loop;
                    }
                    None
                    (None, entry.request.input_length)
                }
                Some(_block_allocator) => {
                    prefill_tokens += entry.request.input_length;
                    let max_new_tokens = match self.window_size {
                        None => entry.request.stopping_parameters.max_new_tokens,
                        Some(window_size) => min(
                            window_size.saturating_sub(entry.request.input_length),
                            entry.request.stopping_parameters.max_new_tokens,
                        ),
                    };
                    decode_tokens += max_new_tokens;

                    if prefill_tokens > prefill_token_budget
                        || (prefill_tokens + decode_tokens + self.speculate) > token_budget
                    {
                        // Entry is over budget
                        // Add it back to the front
                        tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
                        self.entries.push_front((id, entry));
                        break;
                    }

                    let tokens = entry.request.input_length
                        + entry.request.stopping_parameters.max_new_tokens
                        + self.speculate
                        - 1;

                Some(block_allocator) => {
                    // If the user wants the prefill logprobs, we cannot reuse the cache.
                    // So no input_ids for the radix tree.
                    let input_ids = if entry.request.decoder_input_details {
@@ -321,10 +304,65 @@ impl State {
                        entry.request.input_ids.clone()
                    };

                    Some((tokens, input_ids))
                    let tokens = entry.request.input_length
                        + entry.request.stopping_parameters.max_new_tokens
                        + self.speculate
                        - 1;
                    tracing::debug!("Allocating {tokens} with {input_ids:?}");

                    let block_allocation = match block_allocator.allocate(tokens, input_ids).await {
                        None => {
                            // Entry is over budget
                            // Add it back to the front
                            tracing::debug!("Over budget: not enough free blocks");
                            self.entries.push_front((id, entry));
                            break 'entry_loop;
                        }
                        Some(mut block_allocation) => {
                            tracing::debug!("Allocation: {block_allocation:?}");
                            max_blocks = max(max_blocks, block_allocation.blocks.len() as u32);

                            if block_allocation.prefix_len == entry.request.input_length {
                                // The whole request was found in the radix trie
                                // However, for the transformer forward to work, we need to
                                // have at least one token of postfix.
                                block_allocation.prefix_len -= 1;
                            }

                            block_allocation
                        }
                    };
            batch.push((id, entry, block_allocation));

                    let mut postfix_len = entry.request.input_length - block_allocation.prefix_len;

                    // Check equality too; otherwise we might end up with postfix_len = 0
                    // in the next iteration of the loop
                    if prefill_tokens + postfix_len >= prefill_token_budget {
                        // Entry is over budget
                        if self.support_chunking {
                            // We support chunking, so set postfix_len to exactly match prefill_token_budget
                            postfix_len = prefill_token_budget - prefill_tokens;
                            // Push this entry inside the batch
                            batch.push((id, entry, Some(block_allocation), postfix_len));
                            break 'entry_loop;
                        } else {
                            // We don't support chunking, so this entry needs to go back to the buffer
                            // Add it back to the front
                            tracing::debug!(
                                "Over budget: prefill_tokens={} > {prefill_token_budget}",
                                prefill_tokens + postfix_len
                            );
                            self.entries.push_front((id, entry));
                            break 'entry_loop;
                        }
                    }

                    prefill_tokens += postfix_len;

                    (Some(block_allocation), postfix_len)
                }
            };
            batch.push((id, entry, block_allocation, postfix_len));
            if Some(batch.len()) == max_size {
                break;
            }
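Note: the scheduling arithmetic above boils down to a short sketch (hypothetical Python, simplified from the Rust): the postfix is whatever part of the input the radix-trie prefix did not cover, and an over-budget postfix is either clamped so the batch lands exactly on the prefill budget (chunking) or the whole entry is pushed back onto the queue (no chunking).

def split_postfix(
    input_length: int,
    prefix_len: int,
    prefill_tokens: int,
    prefill_token_budget: int,
    support_chunking: bool,
):
    """Illustrative sketch of the prefix/postfix split in the queue code above."""
    if prefix_len == input_length:
        # The whole request was found in the radix trie; keep one token of postfix
        # so the forward pass has something to compute.
        prefix_len -= 1
    postfix_len = input_length - prefix_len

    if prefill_tokens + postfix_len >= prefill_token_budget:
        if not support_chunking:
            # Over budget and no chunking: the entry goes back to the front of the queue.
            return None
        # Chunking: only prefill enough tokens to exactly fill the remaining budget.
        postfix_len = prefill_token_budget - prefill_tokens

    return prefix_len, postfix_len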
@@ -342,7 +380,7 @@ impl State {
        // Batch is too small
        if batch.len() < min_size {
            // Add back entries to the queue in the correct order
            for (id, entry, _) in batch.into_iter().rev() {
            for (id, entry, _, _) in batch.into_iter().rev() {
                self.entries.push_front((id, entry));
            }
            return None;
@@ -353,29 +391,7 @@ impl State {
        let mut batch_entries =
            IntMap::with_capacity_and_hasher(self.entries.len(), BuildNoHashHasher::default());

        for (id, mut entry, block_allocation) in batch {
            let block_allocation = if let (Some((tokens, input_ids)), Some(block_allocator)) =
                (block_allocation, &self.block_allocator)
            {
                tracing::debug!("Allocating {tokens} with {input_ids:?}");
                match block_allocator.allocate(tokens, input_ids).await {
                    None => {
                        // Entry is over budget
                        // Add it back to the front
                        tracing::debug!("Over budget: not enough free blocks");
                        self.entries.push_front((id, entry));
                        continue;
                    }
                    Some(block_allocation) => {
                        tracing::debug!("Allocation: {block_allocation:?}");
                        max_blocks = max(max_blocks, block_allocation.blocks.len() as u32);
                        Some(block_allocation)
                    }
                }
            } else {
                None
            };
            tracing::debug!("Accepting entry");
        for (id, mut entry, block_allocation, postfix_len) in batch {
            // Create a new span to link the batch back to this entry
            let entry_batch_span = info_span!(parent: &entry.span, "infer");
            // Add relationships
@@ -429,6 +445,7 @@ impl State {
                slots,
                prefix_len,
                adapter_id: entry.request.adapter_id.clone(),
                postfix_len,
            });
            // Set batch_time
            entry.batch_time = Some(Instant::now());
@@ -436,12 +453,6 @@ impl State {
            batch_entries.insert(id, entry);
        }

        // Empty batch
        if batch_requests.is_empty() {
            tracing::debug!("Filtered out all entries");
            return None;
        }

        // Final batch size
        let size = batch_requests.len() as u32;
        next_batch_span.record("batch_size", size);
@@ -159,6 +159,7 @@ async fn prefill(
            blocks: vec![],
            slots: vec![],
            prefix_len: 0,
            postfix_len: sequence_length,
            adapter_id: None,
        })
        .collect();
@@ -34,6 +34,7 @@ message InfoResponse {
  string device_type = 3;
  optional uint32 window_size = 4;
  uint32 speculate = 5;
  bool support_chunking = 6;
}

/// Empty request
@@ -139,6 +140,8 @@ message Request {
  uint32 prefix_len = 12;
  /// Context truncation
  bool add_special_tokens = 13;
  /// Postfix length for prefill chunking
  uint32 postfix_len = 14;
}

message Batch {
@@ -163,6 +166,8 @@ message CachedBatch {
  uint32 size = 3;
  /// Maximum number of tokens this batch will grow to
  uint32 max_tokens = 4;
  /// Number of tokens in the next forward
  uint32 current_tokens = 5;
}

enum FinishReason {
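Note: for illustration, a hypothetical request built against the updated proto (invented values), showing the new postfix_len field next to the existing prefix_len and add_special_tokens:

from text_generation_server.pb import generate_pb2

# Hypothetical field values; only the chunking-related fields matter here.
request = generate_pb2.Request(
    id=0,
    prefix_len=128,           # field 12: tokens already cached for this request
    add_special_tokens=True,  # field 13
    postfix_len=64,           # field 14: tokens to prefill during this chunk
)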
@@ -1,7 +1,7 @@
import pytest
import os
from text_generation_server.pb import generate_pb2


@pytest.fixture
def default_pb_parameters():
    return generate_pb2.NextTokenChooserParameters(
@@ -76,6 +76,7 @@ class CausalLMBatch(Batch):
            request_ids=[r.id for r in self.requests],
            size=len(self),
            max_tokens=self.max_tokens,
            current_tokens=len(self),
        )

    @classmethod
@@ -16,7 +16,17 @@ from transformers import (
    AutoTokenizer,
    GenerationConfig,
)
from typing import Any, ContextManager, Iterable, Optional, Tuple, List, Type, Dict, Union
from typing import (
    Any,
    ContextManager,
    Iterable,
    Optional,
    Tuple,
    List,
    Type,
    Dict,
    Union,
)

from text_generation_server.adapters import AdapterBatchData, AdapterBatchMetadata
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
@@ -24,6 +34,10 @@ from text_generation_server.utils.chunks import concat_text_chunks
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models import Model
from text_generation_server.utils.log import log_master
from text_generation_server.utils.prefill_chunking import (
    get_support_chunking,
    get_max_prefill_tokens,
)
from text_generation_server.utils.tokens import batch_top_tokens
from text_generation_server.utils.speculate import get_speculate
from text_generation_server.utils import (
@@ -60,12 +74,9 @@ from text_generation_server.utils.import_utils import (

tracer = trace.get_tracer(__name__)


# Will be set in init
SLIDING_WINDOW: Optional[int] = None

TOKEN_BUDGET = 8


def set_sliding_window(sliding_window: int):
    global SLIDING_WINDOW
@@ -206,6 +217,11 @@ class FlashCausalLMBatch(Batch):
            request_ids=[r.id for r in self.requests],
            size=len(self),
            max_tokens=self.num_blocks * BLOCK_SIZE,
            current_tokens=(
                sum([len(i) for i in self.input_ids])
                if isinstance(self.input_ids, list)
                else len(self.input_ids)
            ),
        )

    @classmethod
@@ -272,7 +288,7 @@ class FlashCausalLMBatch(Batch):
            prompt_lengths.append(prompt_length)

            prefix_length = r.prefix_len
            postfix_length = prefix_length + 10
            postfix_length = r.postfix_len
            assert (
                prefix_length <= prompt_length
            ), f"Prefix {prefix_length} vs input {prompt_length}"
@@ -282,10 +298,13 @@ class FlashCausalLMBatch(Batch):
            if prefix_length + postfix_length < prompt_length:
                # FIXME: speculate is not supported for context chunking at the moment
                assert speculate == 0
                assert get_support_chunking()
                assert postfix_length > 0

            prefix_ids.append(tokenized_input[:prefix_length])
            postfix_ids = tokenized_input[prefix_length : postfix_length]
            # postfix_ids = tokenized_input[prefix_length:]
            postfix_ids = tokenized_input[
                prefix_length : prefix_length + postfix_length
            ]

            postfix_length = len(postfix_ids)
            postfix_lengths.append(postfix_length)
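Note: a tiny worked example of the slicing above, with hypothetical token ids: the chunk sent to the model starts right after the cached prefix and spans postfix_length tokens.

# Hypothetical token ids, purely for illustration.
tokenized_input = [101, 7592, 2088, 2003, 1037, 3231, 102]
prefix_length, postfix_length = 3, 2

prefix_ids = tokenized_input[:prefix_length]  # [101, 7592, 2088]
postfix_ids = tokenized_input[prefix_length : prefix_length + postfix_length]  # [2003, 1037]
assert len(postfix_ids) == postfix_length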
@@ -371,7 +390,6 @@ class FlashCausalLMBatch(Batch):
            requests=pb.requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=all_postfix_ids,

            block_tables=block_tables,
            block_tables_tensor=block_tables_tensor,
            prefix_lengths=prefix_lengths,
@@ -395,7 +413,6 @@ class FlashCausalLMBatch(Batch):
            max_blocks=max_blocks,
            speculative_ids=None,
            prompt_lengths_tensor=prompt_lengths_tensor,

            # These values will be set by `FlashCausalLMBatch.prepare_for_prefill`
            position_ids=None,
            cu_seqlen_prefill=None,
@@ -431,7 +448,7 @@ class FlashCausalLMBatch(Batch):
        if len(request_ids) == len(self):
            return self

        device = self.input_ids.device
        device = self.block_tables_tensor.device

        # New values after filtering
        requests_idx_mapping = {}
@@ -643,24 +660,24 @@ class FlashCausalLMBatch(Batch):
        max_current_length = 0
        for b in batches:
            total_batch_size += len(b)
            total_slots += len(b.slots)
            max_blocks = max(max_blocks, b.max_blocks)
            # If `b` is prefilling and was just filtered, `b.slots` is None
            # `total_slots` is not used if any of the batches is prefilling
            total_slots += len(b.slots) if not b.prefilling else 0
            num_blocks += b.num_blocks
            speculative_length = (
                b.speculative_ids.shape[1] if b.speculative_ids is not None else 0
            )
            max_blocks = max(max_blocks, b.max_blocks)
            max_postfix_length = max(max_postfix_length, b.max_postfix_length)
            max_current_length = max(max_current_length, b.max_current_length)
            max_length = max(
                max_length,
                max(
                    prefix_length
                    + postfix_length
                    prompt_length
                    + stopping_criteria.max_new_tokens
                    + speculative_length
                    - stopping_criteria.current_tokens
                    for prefix_length, postfix_length, stopping_criteria in zip(
                        b.prefix_lengths, b.postfix_lengths, b.stopping_criterias
                    for prompt_length, stopping_criteria in zip(
                        b.prompt_lengths, b.stopping_criterias
                    )
                ),
            )
@@ -746,8 +763,6 @@ class FlashCausalLMBatch(Batch):

            start_index = cumulative_batch_size
            end_index = cumulative_batch_size + len(batch)
            slots_start_index = cumulative_slots
            slots_end_index = cumulative_slots + len(batch.slots)

            # Copy tensors (GPU)
            top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
@@ -761,10 +776,17 @@ class FlashCausalLMBatch(Batch):
            prompt_lengths_tensor[start_index:end_index] = batch.prompt_lengths_tensor

            if not prefilling:
                slots_start_index = cumulative_slots
                slots_end_index = cumulative_slots + len(batch.slots)

                input_ids[start_index:end_index] = batch.input_ids
                position_ids[start_index:end_index] = batch.position_ids
                slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots
                postfix_lengths_tensor[start_index:end_index] = batch.postfix_lengths_tensor
                slot_indices[start_index:end_index] = (
                    batch.slot_indices + cumulative_slots
                )
                postfix_lengths_tensor[start_index:end_index] = (
                    batch.postfix_lengths_tensor
                )
                slots[slots_start_index:slots_end_index] = batch.slots

                # Copy over adapter indices
@@ -779,11 +801,17 @@ class FlashCausalLMBatch(Batch):
                cumulative_adapter_indices_size = adapter_end_index
                adapter_set.update(batch.adapter_meta.adapter_set)
                adapter_segment_builder.concat(
                    batch.adapter_meta.adapter_segments, batch.adapter_meta.segment_indices
                    batch.adapter_meta.adapter_segments,
                    batch.adapter_meta.segment_indices,
                )
                prefix_lengths_tensor[start_index:end_index] = (
                    batch.prefix_lengths_tensor
                )
                prefix_lengths_tensor[start_index:end_index] = batch.prefix_lengths_tensor

                start_slots.append(batch.start_slots + cumulative_slots)

                # Update
                cumulative_slots += len(batch.slots)
            else:
                if isinstance(batch.input_ids, torch.Tensor):
                    batch.input_ids = batch.input_ids.view(-1, 1).tolist()
@@ -810,7 +838,6 @@ class FlashCausalLMBatch(Batch):

            # Update
            cumulative_batch_size += len(batch)
            cumulative_slots += len(batch.slots)

        if start_slots is not None:
            start_slots = torch.concat(start_slots)
@@ -915,7 +942,7 @@ class FlashCausalLMBatch(Batch):
            postfix_length,
            prompt_length,
            request_prefilling,
            blocks
            blocks,
        ) in enumerate(
            zip(
                self.requests,
@@ -923,7 +950,7 @@ class FlashCausalLMBatch(Batch):
                self.postfix_lengths,
                self.prompt_lengths,
                self.prefilling_mask,
                self.block_tables
                self.block_tables,
            )
        ):
            next_chunk_length = postfix_length
@@ -967,9 +994,7 @@ class FlashCausalLMBatch(Batch):
            no_prefill_logprobs = no_prefill_logprobs and not prefill_logprobs

            if prefill_logprobs:
                prefill_head_indices.append(
                    request_position_ids + cumulative_length
                )
                prefill_head_indices.append(request_position_ids + cumulative_length)
                prefill_next_token_indices.append(
                    prefill_out_cumulative_length + postfix_length - 1
                )
@@ -988,7 +1013,6 @@ class FlashCausalLMBatch(Batch):
                prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
                prefill_out_cumulative_length += 1


            start_slots.append(cumulative_slot_tokens)
            slots.extend(request_slots)
            slot_indices.append(request_slot_indices)
@@ -998,9 +1022,7 @@ class FlashCausalLMBatch(Batch):

            ADAPTER_TO_INDEX = get_adapter_to_index()
            adapter_index = ADAPTER_TO_INDEX.get(r.adapter_id, 0)
            adapter_indices_list.append(
                torch.full((next_chunk_length,), adapter_index)
            )
            adapter_indices_list.append(torch.full((next_chunk_length,), adapter_index))
            adapter_set.add(adapter_index)

            # Update
@@ -1240,6 +1262,7 @@ class FlashCausalLM(Model):
            rank=rank,
            world_size=world_size,
            sliding_window=config.sliding_window,
            support_chunking=True,
        )

    @property
@@ -1764,13 +1787,18 @@ class FlashCausalLM(Model):
        finished_prefilling = True
        next_chunk_lengths = []
        if prefill:
            if get_support_chunking():
                next_prefilling_mask = []
                # Budget in tokens for the next batch
                # We remove next input ids to always have enough space for at least a single decode
                # We remove len(batch) to always have enough space for at least a single decode
                # for the remaining requests
                batch_budget = TOKEN_BUDGET - len(batch)
                batch_budget = get_max_prefill_tokens() - len(batch)
                # We reverse to prioritize older requests
                # zip() is not reversible so reverse the underlying lists instead
                for prefix_length, postfix_length, prompt_length in zip(
                    batch.prefix_lengths, batch.postfix_lengths, batch.prompt_lengths
                    reversed(batch.prefix_lengths),
                    reversed(batch.postfix_lengths),
                    reversed(batch.prompt_lengths),
                ):
                    remaining_prefill_tokens = max(
                        prompt_length - prefix_length - postfix_length, 0
@@ -1788,6 +1816,15 @@ class FlashCausalLM(Model):
                        next_prefilling_mask.append(False)
                    next_chunk_lengths.append(next_chunk_length)

                # Reverse back the obtained values
                next_chunk_lengths.reverse()
                next_prefilling_mask.reverse()
            else:
                # The model does not support chunking
                # We know we only do a single prefill
                finished_prefilling = True
                next_prefilling_mask = [False] * len(batch)

        batch.prefilling = not finished_prefilling
        batch.prefilling_mask = next_prefilling_mask
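Note: a hedged sketch of the chunk-budgeting loop above (hypothetical helper, simplified bookkeeping): the tokens still to prefill for each request are clamped to what is left of get_max_prefill_tokens() after reserving one token per request for decode, and the lists are walked in reverse so older requests are served first.

def plan_next_chunks(prefix_lengths, postfix_lengths, prompt_lengths, max_prefill_tokens):
    """Illustrative sketch of the budgeting loop shown above; not the actual implementation."""
    batch_budget = max_prefill_tokens - len(prompt_lengths)
    next_chunk_lengths, next_prefilling_mask = [], []
    # Iterate in reverse to prioritize older requests, then reverse the results back.
    for prefix_length, postfix_length, prompt_length in zip(
        reversed(prefix_lengths), reversed(postfix_lengths), reversed(prompt_lengths)
    ):
        remaining_prefill_tokens = max(prompt_length - prefix_length - postfix_length, 0)
        next_chunk_length = min(remaining_prefill_tokens, max(batch_budget, 0))
        batch_budget -= next_chunk_length
        # A request keeps prefilling next step only if it still has a chunk to process.
        next_prefilling_mask.append(next_chunk_length > 0)
        next_chunk_lengths.append(next_chunk_length)
    next_chunk_lengths.reverse()
    next_prefilling_mask.reverse()
    return next_chunk_lengths, next_prefilling_mask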
@@ -2179,7 +2216,9 @@ class FlashCausalLM(Model):
                # have more than one new token per request with speculative decoding
                for next_token_id in _next_token_ids:
                    batch.next_token_chooser = (
                        batch.next_token_chooser.advance_grammar_single(i, next_token_id)
                        batch.next_token_chooser.advance_grammar_single(
                            i, next_token_id
                        )
                    )

                # Update values
@@ -18,7 +18,7 @@ if PREFIX_CACHING and ATTENTION not in {"flashinfer", "flashdecoding"}:
    raise RuntimeError("Prefix caching is only supported with flashinfer")

MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
TGI_WIGGLE_ROOM = float(os.getenv("TGI_WIGGLE_ROOM", "0.95"))
TGI_WIGGLE_ROOM = float(os.getenv("TGI_WIGGLE_ROOM", "0.90"))
assert TGI_WIGGLE_ROOM > 0
assert TGI_WIGGLE_ROOM < 1
@@ -83,6 +83,7 @@ class IdeficsCausalLMBatch(Batch):
            request_ids=[r.id for r in self.requests],
            size=len(self),
            max_tokens=self.max_tokens,
            current_tokens=len(self),
        )

    @classmethod
@@ -116,6 +116,7 @@ class MambaBatch(Batch):
            request_ids=[r.id for r in self.requests],
            size=len(self),
            max_tokens=self.max_tokens,
            current_tokens=len(self),
        )

    @classmethod
@@ -5,8 +5,11 @@ from abc import ABC, abstractmethod
from typing import List, Tuple, Optional, TypeVar, Type, Dict
from collections import defaultdict
from transformers import PreTrainedTokenizerBase
from loguru import logger

from text_generation_server.models.types import Batch, Generation
from text_generation_server.utils.log import log_master
from text_generation_server.utils.prefill_chunking import set_support_chunking
from text_generation_server.utils.speculate import get_speculate
from text_generation_server.pb.generate_pb2 import InfoResponse
from text_generation_server.adapters.weights import LayerAdapterWeights
@@ -31,6 +34,7 @@ class Model(ABC):
        sliding_window: Optional[int] = None,
        speculate: Optional[int] = None,
        adapter_id: str = BASE_MODEL_ADAPTER_ID,
        support_chunking: bool = False,
    ):
        self.model_id = model_id
        self.model = model.eval()
@@ -60,6 +64,17 @@ class Model(ABC):
            speculate = get_speculate()
        self.speculate = speculate

        if speculate != 0 and support_chunking:
            log_master(
                logger.warning,
                "Prefill chunking does not support speculation yet. "
                "Prefill chunking will be turned off",
            )
            support_chunking = False

        self.support_chunking = support_chunking
        set_support_chunking(support_chunking)

        self.has_position_ids = (
            inspect.signature(model.forward).parameters.get("position_ids", None)
            is not None
@@ -78,6 +93,7 @@ class Model(ABC):
            device_type=self.device.type,
            window_size=self.sliding_window,
            speculate=self.speculate,
            support_chunking=self.support_chunking,
        )

    @property
@@ -80,6 +80,7 @@ class Seq2SeqLMBatch(Batch):
            request_ids=[r.id for r in self.requests],
            size=len(self),
            max_tokens=self.max_tokens,
            current_tokens=len(self),
        )

    @classmethod
@@ -357,7 +357,6 @@ class VlmCausalLM(FlashCausalLM):
        else:
            cuda_graph = None
        if cu_seqlen_prefill is not None or cuda_graph is None:
            input_lengths = postfix_lengths + prefix_lengths_tensor
            if PREFIX_CACHING:
                block_tables = block_tables_to_ragged(
                    block_tables=block_tables,
@@ -15,6 +15,7 @@ from text_generation_server.cache import Cache
from text_generation_server.interceptor import ExceptionInterceptor
from text_generation_server.models import Model, get_model_with_lora_adapters
from text_generation_server.utils.adapter import AdapterInfo
from text_generation_server.utils.prefill_chunking import set_max_prefill_tokens

try:
    from text_generation_server.models.pali_gemma import PaliGemmaBatch
@@ -96,6 +97,8 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
        return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())

    async def Warmup(self, request, context):
        set_max_prefill_tokens(request.max_prefill_tokens)

        if self.quantize in {"exl2", "gptq"}:
            try:
                # When using GPTQ, Exllama kernels need some global kernels
server/text_generation_server/utils/prefill_chunking.py (new file, 24 lines)
@@ -0,0 +1,24 @@
from typing import Optional

SUPPORT_CHUNKING: Optional[bool] = None
MAX_PREFILL_TOKENS: Optional[int] = None


def set_support_chunking(support_chunking: bool):
    global SUPPORT_CHUNKING
    SUPPORT_CHUNKING = support_chunking


def get_support_chunking() -> bool:
    global SUPPORT_CHUNKING
    return SUPPORT_CHUNKING


def set_max_prefill_tokens(max_prefill_tokens: int):
    global MAX_PREFILL_TOKENS
    MAX_PREFILL_TOKENS = max_prefill_tokens


def get_max_prefill_tokens() -> int:
    global MAX_PREFILL_TOKENS
    return MAX_PREFILL_TOKENS
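Note: a minimal usage sketch of the new module, assuming the call sites introduced in this commit: the model records support_chunking once at construction time, Warmup records the prefill token budget, and the batching code reads both later.

from text_generation_server.utils.prefill_chunking import (
    get_max_prefill_tokens,
    get_support_chunking,
    set_max_prefill_tokens,
    set_support_chunking,
)

# Typically set once at startup (Model.__init__) and during Warmup.
set_support_chunking(True)
set_max_prefill_tokens(4096)

# Read back wherever the prefill chunk sizes are planned.
if get_support_chunking():
    budget = get_max_prefill_tokens()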