load tested

OlivierDehaene 2024-10-02 12:59:44 +02:00
parent 34f5dc525e
commit 7f9abde3f8
22 changed files with 307 additions and 195 deletions

View File

@ -159,6 +159,7 @@ impl Client {
blocks: vec![],
slots: vec![],
prefix_len: 0,
postfix_len: truncate,
// Set sampling parameters to also take these ops into account in the max memory
parameters: Some(NextTokenChooserParameters {
temperature: 0.9,

View File

@ -246,6 +246,7 @@ impl Health for ShardedClient {
blocks: vec![0],
slots: (0..16).collect(),
prefix_len: 0,
postfix_len: 1,
adapter_id: None,
};
let batch = Batch {

View File

@ -34,9 +34,13 @@ impl BackendV3 {
requires_padding: bool,
window_size: Option<u32>,
speculate: u32,
support_chunking: bool,
) -> Self {
let prefix_caching =
std::env::var("USE_PREFIX_CACHING").unwrap_or("1".to_string());
if support_chunking {
tracing::warn!("Model supports prefill chunking. `waiting_served_ratio` and `max_waiting_tokens` will be ignored.");
}
let prefix_caching = std::env::var("USE_PREFIX_CACHING").unwrap_or("1".to_string());
let prefix_caching = matches!(prefix_caching.as_str(), "true" | "1");
let attention: String = std::env::var("ATTENTION").unwrap_or("flashinfer".to_string());
@ -52,6 +56,7 @@ impl BackendV3 {
window_size,
speculate,
max_batch_total_tokens,
support_chunking,
);
let batching_task_notifier = Arc::new(Notify::new());
@ -63,6 +68,7 @@ impl BackendV3 {
max_batch_total_tokens,
max_waiting_tokens,
max_batch_size,
support_chunking,
queue.clone(),
batching_task_notifier.clone(),
));
@ -127,6 +133,7 @@ pub(crate) async fn batching_task(
max_batch_total_tokens: u32,
max_waiting_tokens: usize,
max_batch_size: Option<usize>,
support_chunking: bool,
queue: Queue,
notifier: Arc<Notify>,
) {
@ -158,28 +165,44 @@ pub(crate) async fn batching_task(
// Get current batch info
let batch_size = batch.size;
let batch_max_tokens = batch.max_tokens;
let current_tokens = batch.current_tokens;
let mut batches = vec![batch];
metrics::gauge!("tgi_batch_current_size").set(batch_size as f64);
metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64);
let min_size = if waiting_tokens >= max_waiting_tokens {
// If we didn't onboard any new requests since >= max_waiting_tokens, we try
// to add a new batch even though its size might be small
None
} else {
// Minimum batch size
// TODO: temporarily disable to avoid incorrect deallocation +
// reallocation when using prefix caching.
Some((batch_size as f32 * waiting_served_ratio).floor() as usize)
};
let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
let max_size =
max_batch_size.map(|max_size| max_size.saturating_sub(batch_size as usize));
let (min_size, max_size, prefill_token_budget) = if support_chunking {
// Since the next batch will be concatenated with the current batch,
// the current batch tokens must be subtracted from the prefill budget
// In the future, we could concatenate beforehand
let prefill_token_budget = max_batch_prefill_tokens - current_tokens;
// We can ignore min_size and max_size
// Models that rely on max_size cannot support chunking
// Regarding min_size, chunking allows us to consistently run at the compute
// bound, making min_size useless.
(None, None, prefill_token_budget)
} else {
let min_size = if waiting_tokens >= max_waiting_tokens {
// If we didn't onboard any new requests since >= max_waiting_tokens, we try
// to add a new batch even though its size might be small
None
} else {
// Minimum batch size
// TODO: temporarily disable to avoid incorrect deallocation +
// reallocation when using prefix caching.
Some((batch_size as f32 * waiting_served_ratio).floor() as usize)
};
let max_size =
max_batch_size.map(|max_size| max_size.saturating_sub(batch_size as usize));
(min_size, max_size, max_batch_prefill_tokens)
};
// Try to get a new batch
if let Some((mut new_entries, new_batch, span)) = queue
.next_batch(min_size, max_size, max_batch_prefill_tokens, token_budget)
.next_batch(min_size, max_size, prefill_token_budget, token_budget)
.await
{
// Tracking metrics

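The hunk above is the core router-side change: when the model supports prefill chunking, the budget for the next prefill is whatever remains of `max_batch_prefill_tokens` after subtracting the tokens already scheduled in the current batch, and `min_size`/`max_size` are dropped. A minimal Python sketch of that selection (the helper and its signature are illustrative, not part of the codebase):

from typing import Optional, Tuple

def select_budgets(
    support_chunking: bool,
    batch_size: int,
    current_tokens: int,
    max_batch_prefill_tokens: int,
    waiting_tokens: int,
    max_waiting_tokens: int,
    waiting_served_ratio: float,
    max_batch_size: Optional[int],
) -> Tuple[Optional[int], Optional[int], int]:
    """Return (min_size, max_size, prefill_token_budget) as in batching_task."""
    if support_chunking:
        # The next batch is concatenated with the running one, so the tokens
        # already scheduled for the next forward are subtracted from the budget.
        # min_size/max_size are irrelevant: chunking keeps prefill compute-bound.
        return None, None, max_batch_prefill_tokens - current_tokens
    if waiting_tokens >= max_waiting_tokens:
        # Nothing new was onboarded for a while: accept any batch size.
        min_size = None
    else:
        min_size = int(batch_size * waiting_served_ratio)
    max_size = None if max_batch_size is None else max(max_batch_size - batch_size, 0)
    return min_size, max_size, max_batch_prefill_tokens
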
View File

@ -159,6 +159,7 @@ impl Client {
blocks: vec![],
slots: vec![],
prefix_len: 0,
postfix_len: truncate,
// Set sampling parameters to also take these ops into account in the max memory
parameters: Some(NextTokenChooserParameters {
temperature: 0.9,

View File

@ -29,15 +29,6 @@ pub trait Health {
async fn model_health(&self) -> Result<()>;
}
#[derive(Debug)]
pub struct ShardInfo {
pub requires_padding: bool,
pub dtype: String,
pub device_type: String,
pub window_size: Option<u32>,
pub speculate: u32,
}
#[derive(Error, Debug, Clone)]
pub enum ClientError {
#[error("Could not connect to Text Generation server: {0}")]

View File

@ -1,6 +1,6 @@
use crate::client::{ClientError, Result};
use crate::client::Health;
/// Multi shard Client
use crate::client::{Health, ShardInfo};
use crate::client::{ClientError, Result};
use crate::client::grpc_client::{DecodeTimings, PrefillTimings};
use crate::client::{
@ -49,13 +49,13 @@ impl ShardedClient {
/// Get the model info
#[instrument(skip(self))]
pub async fn info(&mut self) -> Result<ShardInfo> {
pub async fn info(&mut self) -> Result<InfoResponse> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| client.info())
.collect();
join_all(futures).await.pop().unwrap().map(ShardInfo::from)
join_all(futures).await.pop().unwrap()
}
/// GRPC health check
@ -194,18 +194,6 @@ impl ShardedClient {
}
}
impl From<InfoResponse> for ShardInfo {
fn from(value: InfoResponse) -> Self {
Self {
requires_padding: value.requires_padding,
dtype: value.dtype,
device_type: value.device_type,
window_size: value.window_size,
speculate: value.speculate,
}
}
}
#[async_trait]
impl Health for ShardedClient {
async fn device_health(&self) -> Result<()> {
@ -248,6 +236,7 @@ impl Health for ShardedClient {
slots: (0..16).collect(),
prefix_len: 0,
adapter_id: None,
postfix_len: 1,
};
let batch = Batch {
id: u64::MAX,

View File

@ -29,6 +29,8 @@ pub struct BackendInfo {
pub max_waiting_tokens: usize,
#[schema(nullable = true, example = "null")]
pub max_batch_size: Option<usize>,
#[schema(example = "false")]
pub support_chunking: bool,
}
#[allow(clippy::too_many_arguments)]
@ -110,6 +112,7 @@ pub async fn connect_backend(
model_device_type: shard_info.device_type.clone(),
model_dtype: shard_info.dtype.clone(),
speculate: shard_info.speculate as usize,
support_chunking: shard_info.support_chunking,
};
let backend = BackendV3::new(
@ -122,6 +125,7 @@ pub async fn connect_backend(
shard_info.requires_padding,
shard_info.window_size,
shard_info.speculate,
shard_info.support_chunking,
);
tracing::info!("Using backend V3");

View File

@ -131,25 +131,12 @@ async fn main() -> Result<(), RouterError> {
"`max_input_tokens` must be < `max_total_tokens`".to_string(),
));
}
if max_input_tokens as u32 > max_batch_prefill_tokens {
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}")));
}
if validation_workers == 0 {
return Err(RouterError::ArgumentValidation(
"`validation_workers` must be > 0".to_string(),
));
}
if let Some(ref max_batch_total_tokens) = max_batch_total_tokens {
if max_batch_prefill_tokens > *max_batch_total_tokens {
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")));
}
if max_total_tokens as u32 > *max_batch_total_tokens {
return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")));
}
}
if let Some(max_batch_size) = max_batch_size {
if max_batch_size == 0 {
return Err(RouterError::ArgumentValidation(
@ -158,7 +145,7 @@ async fn main() -> Result<(), RouterError> {
}
}
let (backend, _backend_info) = connect_backend(
let (backend, backend_info) = connect_backend(
max_input_tokens,
max_total_tokens,
master_shard_uds_path,
@ -170,6 +157,19 @@ async fn main() -> Result<(), RouterError> {
)
.await?;
// Validate remaining args now that the backend is known
let support_chunking = backend_info.support_chunking;
let max_batch_total_tokens = backend_info.max_batch_total_tokens;
if max_input_tokens as u32 > max_batch_prefill_tokens && !support_chunking {
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}")));
}
if max_batch_prefill_tokens > max_batch_total_tokens {
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")));
}
if max_total_tokens as u32 > max_batch_total_tokens {
return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")));
}
// Run server
server::run(
backend,

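Since `support_chunking` and the final `max_batch_total_tokens` are only known once the backend is connected, these checks now run after `connect_backend`. A hypothetical helper summarizing the reordered validation (not the actual router code):

def validate_args(
    max_input_tokens: int,
    max_total_tokens: int,
    max_batch_prefill_tokens: int,
    max_batch_total_tokens: int,
    support_chunking: bool,
) -> None:
    """Mirror of the checks performed once the backend is known."""
    if max_input_tokens > max_batch_prefill_tokens and not support_chunking:
        # Without chunking, a full prompt must fit in a single prefill pass.
        raise ValueError(
            f"`max_batch_prefill_tokens` must be >= `max_input_tokens`. "
            f"Given: {max_batch_prefill_tokens} and {max_input_tokens}"
        )
    if max_batch_prefill_tokens > max_batch_total_tokens:
        raise ValueError(
            f"`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. "
            f"Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}"
        )
    if max_total_tokens > max_batch_total_tokens:
        raise ValueError(
            f"`max_total_tokens` must be <= `max_batch_total_tokens`. "
            f"Given: {max_total_tokens} and {max_batch_total_tokens}"
        )
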
View File

@ -4,7 +4,7 @@ use crate::client::{
Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
};
use nohash_hasher::{BuildNoHashHasher, IntMap};
use std::cmp::{max, min};
use std::cmp::max;
use std::collections::VecDeque;
use text_generation_router::infer::InferError;
use text_generation_router::infer::InferStreamResponse;
@ -50,6 +50,7 @@ impl Queue {
window_size: Option<u32>,
speculate: u32,
max_batch_total_tokens: u32,
support_chunking: bool,
) -> Self {
// Create channel
let (queue_sender, queue_receiver) = mpsc::unbounded_channel();
@ -62,6 +63,7 @@ impl Queue {
window_size,
speculate,
max_batch_total_tokens,
support_chunking,
queue_receiver,
));
@ -108,6 +110,7 @@ impl Queue {
}
// Background task responsible for the queue state
#[allow(clippy::too_many_arguments)]
async fn queue_task(
requires_padding: bool,
block_size: u32,
@ -115,6 +118,7 @@ async fn queue_task(
window_size: Option<u32>,
speculate: u32,
max_batch_total_tokens: u32,
support_chunking: bool,
mut receiver: mpsc::UnboundedReceiver<QueueCommand>,
) {
let mut state = State::new(
@ -124,6 +128,7 @@ async fn queue_task(
window_size,
speculate,
max_batch_total_tokens,
support_chunking,
);
while let Some(cmd) = receiver.recv().await {
@ -166,12 +171,14 @@ struct State {
/// Paged Attention block size
block_size: u32,
/// Sliding window
window_size: Option<u32>,
/// Speculation amount
speculate: u32,
/// Whether the model allows prefill chunking
/// If it does, the last request in the batch will be split to exactly match the prefill
/// token budget
support_chunking: bool,
/// Paged Attention Block Allocation
block_allocator: Option<BlockAllocator>,
}
@ -184,6 +191,7 @@ impl State {
window_size: Option<u32>,
speculate: u32,
max_batch_total_tokens: u32,
support_chunking: bool,
) -> Self {
let block_allocator = (!requires_padding).then(|| {
BlockAllocator::new(
@ -199,8 +207,8 @@ impl State {
next_id: 0,
next_batch_id: 0,
block_size,
window_size,
speculate,
support_chunking,
block_allocator,
}
}
@ -268,7 +276,7 @@ impl State {
continue;
}
let block_allocation = match &self.block_allocator {
let (block_allocation, postfix_len) = match &self.block_allocator {
None => {
// We pad to max input length in the Python shards
// We need to take these padding tokens into the equation
@ -285,34 +293,9 @@ impl State {
self.entries.push_front((id, entry));
break 'entry_loop;
}
None
(None, entry.request.input_length)
}
Some(_block_allocator) => {
prefill_tokens += entry.request.input_length;
let max_new_tokens = match self.window_size {
None => entry.request.stopping_parameters.max_new_tokens,
Some(window_size) => min(
window_size.saturating_sub(entry.request.input_length),
entry.request.stopping_parameters.max_new_tokens,
),
};
decode_tokens += max_new_tokens;
if prefill_tokens > prefill_token_budget
|| (prefill_tokens + decode_tokens + self.speculate) > token_budget
{
// Entry is over budget
// Add it back to the front
tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
self.entries.push_front((id, entry));
break;
}
let tokens = entry.request.input_length
+ entry.request.stopping_parameters.max_new_tokens
+ self.speculate
- 1;
Some(block_allocator) => {
// If users wants the prefill logprobs, we cannot reuse the cache.
// So no input_ids for the radix tree.
let input_ids = if entry.request.decoder_input_details {
@ -321,10 +304,65 @@ impl State {
entry.request.input_ids.clone()
};
Some((tokens, input_ids))
let tokens = entry.request.input_length
+ entry.request.stopping_parameters.max_new_tokens
+ self.speculate
- 1;
tracing::debug!("Allocating {tokens} with {input_ids:?}");
let block_allocation = match block_allocator.allocate(tokens, input_ids).await {
None => {
// Entry is over budget
// Add it back to the front
tracing::debug!("Over budget: not enough free blocks");
self.entries.push_front((id, entry));
break 'entry_loop;
}
Some(mut block_allocation) => {
tracing::debug!("Allocation: {block_allocation:?}");
max_blocks = max(max_blocks, block_allocation.blocks.len() as u32);
if block_allocation.prefix_len == entry.request.input_length {
// The whole request was found in the radix trie
// However, for the transformer forward to work, we need to
// have at least one token of postfix.
block_allocation.prefix_len -= 1;
}
block_allocation
}
};
let mut postfix_len = entry.request.input_length - block_allocation.prefix_len;
// Check for equality too: otherwise we might end up with a postfix_len = 0
// in the next iteration of the loop
if prefill_tokens + postfix_len >= prefill_token_budget {
// Entry is over budget
if self.support_chunking {
// We support chunking, just set postfix_len to exactly match prefill_token_budget
postfix_len = prefill_token_budget - prefill_tokens;
// Push this entry inside the batch
batch.push((id, entry, Some(block_allocation), postfix_len));
break 'entry_loop;
} else {
// We don't support chunking, this entry needs to go back to the buffer
// Add it back to the front
tracing::debug!(
"Over budget: prefill_tokens={} > {prefill_token_budget}",
prefill_tokens + postfix_len
);
self.entries.push_front((id, entry));
break 'entry_loop;
}
}
prefill_tokens += postfix_len;
(Some(block_allocation), postfix_len)
}
};
batch.push((id, entry, block_allocation));
batch.push((id, entry, block_allocation, postfix_len));
if Some(batch.len()) == max_size {
break;
}
@ -342,7 +380,7 @@ impl State {
// Batch is too small
if batch.len() < min_size {
// Add back entries to the queue in the correct order
for (id, entry, _) in batch.into_iter().rev() {
for (id, entry, _, _) in batch.into_iter().rev() {
self.entries.push_front((id, entry));
}
return None;
@ -353,29 +391,7 @@ impl State {
let mut batch_entries =
IntMap::with_capacity_and_hasher(self.entries.len(), BuildNoHashHasher::default());
for (id, mut entry, block_allocation) in batch {
let block_allocation = if let (Some((tokens, input_ids)), Some(block_allocator)) =
(block_allocation, &self.block_allocator)
{
tracing::debug!("Allocating {tokens} with {input_ids:?}");
match block_allocator.allocate(tokens, input_ids).await {
None => {
// Entry is over budget
// Add it back to the front
tracing::debug!("Over budget: not enough free blocks");
self.entries.push_front((id, entry));
continue;
}
Some(block_allocation) => {
tracing::debug!("Allocation: {block_allocation:?}");
max_blocks = max(max_blocks, block_allocation.blocks.len() as u32);
Some(block_allocation)
}
}
} else {
None
};
tracing::debug!("Accepting entry");
for (id, mut entry, block_allocation, postfix_len) in batch {
// Create a new span to link the batch back to this entry
let entry_batch_span = info_span!(parent: &entry.span, "infer");
// Add relationships
@ -429,6 +445,7 @@ impl State {
slots,
prefix_len,
adapter_id: entry.request.adapter_id.clone(),
postfix_len,
});
// Set batch_time
entry.batch_time = Some(Instant::now());
@ -436,12 +453,6 @@ impl State {
batch_entries.insert(id, entry);
}
// Empty batch
if batch_requests.is_empty() {
tracing::debug!("Filterered out all entries");
return None;
}
// Final batch size
let size = batch_requests.len() as u32;
next_batch_span.record("batch_size", size);

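The queue change boils down to deriving `postfix_len` from the radix-trie `prefix_len` and, when chunking is supported, clamping it to the remaining prefill budget instead of bouncing the entry back to the queue. A minimal sketch of that decision, ignoring block allocation (helper name and signature are hypothetical):

from typing import Optional

def plan_prefill_chunk(
    input_length: int,
    prefix_len: int,
    prefill_tokens: int,
    prefill_token_budget: int,
    support_chunking: bool,
) -> Optional[int]:
    """Return the postfix_len to prefill now, or None if the entry must wait."""
    if prefix_len == input_length:
        # The whole prompt was found in the radix trie; keep at least one
        # postfix token so the forward pass has something to compute.
        prefix_len -= 1
    postfix_len = input_length - prefix_len
    if prefill_tokens + postfix_len >= prefill_token_budget:
        if support_chunking:
            # Split the request so this chunk exactly fills the budget.
            return prefill_token_budget - prefill_tokens
        # No chunking: the entry goes back to the front of the queue.
        return None
    return postfix_len

For example, with input_length=100, prefix_len=40, prefill_tokens=470 and a budget of 512, the call returns 42, leaving the last 18 prompt tokens for a later prefill chunk.
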
View File

@ -159,6 +159,7 @@ async fn prefill(
blocks: vec![],
slots: vec![],
prefix_len: 0,
postfix_len: sequence_length,
adapter_id: None,
})
.collect();

View File

@ -34,6 +34,7 @@ message InfoResponse {
string device_type = 3;
optional uint32 window_size = 4;
uint32 speculate = 5;
bool support_chunking = 6;
}
/// Empty request
@ -139,6 +140,8 @@ message Request {
uint32 prefix_len = 12;
/// Context truncation
bool add_special_tokens = 13;
/// Postfix length for prefill chunking
uint32 postfix_len = 14;
}
message Batch {
@ -163,6 +166,8 @@ message CachedBatch {
uint32 size = 3;
/// Maximum number of tokens this batch will grow to
uint32 max_tokens = 4;
/// Number of tokens in the next forward
uint32 current_tokens = 5;
}
enum FinishReason {

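For illustration, the extended messages could be built like this (field values are made up and most `Request` fields are omitted; only fields visible in this diff are relied upon):

from text_generation_server.pb import generate_pb2

# prefix_len tokens are reused from the KV cache; postfix_len tokens are the
# ones actually run through the model during this prefill step.
req = generate_pb2.Request(id=0, prefix_len=40, postfix_len=42, add_special_tokens=True)

# CachedBatch now also reports how many tokens the next forward will process.
cached = generate_pb2.CachedBatch(
    id=0, request_ids=[0], size=1, max_tokens=512, current_tokens=42
)
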
View File

@ -1,7 +1,7 @@
import pytest
import os
from text_generation_server.pb import generate_pb2
@pytest.fixture
def default_pb_parameters():
return generate_pb2.NextTokenChooserParameters(

View File

@ -76,6 +76,7 @@ class CausalLMBatch(Batch):
request_ids=[r.id for r in self.requests],
size=len(self),
max_tokens=self.max_tokens,
current_tokens=len(self),
)
@classmethod

View File

@ -16,7 +16,17 @@ from transformers import (
AutoTokenizer,
GenerationConfig,
)
from typing import Any, ContextManager, Iterable, Optional, Tuple, List, Type, Dict, Union
from typing import (
Any,
ContextManager,
Iterable,
Optional,
Tuple,
List,
Type,
Dict,
Union,
)
from text_generation_server.adapters import AdapterBatchData, AdapterBatchMetadata
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
@ -24,6 +34,10 @@ from text_generation_server.utils.chunks import concat_text_chunks
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models import Model
from text_generation_server.utils.log import log_master
from text_generation_server.utils.prefill_chunking import (
get_support_chunking,
get_max_prefill_tokens,
)
from text_generation_server.utils.tokens import batch_top_tokens
from text_generation_server.utils.speculate import get_speculate
from text_generation_server.utils import (
@ -60,12 +74,9 @@ from text_generation_server.utils.import_utils import (
tracer = trace.get_tracer(__name__)
# Will be set in init
SLIDING_WINDOW: Optional[int] = None
TOKEN_BUDGET = 8
def set_sliding_window(sliding_window: int):
global SLIDING_WINDOW
@ -206,6 +217,11 @@ class FlashCausalLMBatch(Batch):
request_ids=[r.id for r in self.requests],
size=len(self),
max_tokens=self.num_blocks * BLOCK_SIZE,
current_tokens=(
sum([len(i) for i in self.input_ids])
if isinstance(self.input_ids, list)
else len(self.input_ids)
),
)
@classmethod
@ -272,7 +288,7 @@ class FlashCausalLMBatch(Batch):
prompt_lengths.append(prompt_length)
prefix_length = r.prefix_len
postfix_length = prefix_length + 10
postfix_length = r.postfix_len
assert (
prefix_length <= prompt_length
), f"Prefix {prefix_length} vs input {prompt_length}"
@ -282,10 +298,13 @@ class FlashCausalLMBatch(Batch):
if prefix_length + postfix_length < prompt_length:
# FIXME: speculate is not supported for context chunking at the moment
assert speculate == 0
assert get_support_chunking()
assert postfix_length > 0
prefix_ids.append(tokenized_input[:prefix_length])
postfix_ids = tokenized_input[prefix_length : postfix_length]
# postfix_ids = tokenized_input[prefix_length:]
postfix_ids = tokenized_input[
prefix_length : prefix_length + postfix_length
]
postfix_length = len(postfix_ids)
postfix_lengths.append(postfix_length)
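The slicing above can be pictured with plain lists; a standalone sketch (not server code) of how a prompt splits into the cached prefix, the chunk prefilled now, and the remainder left for later chunks:

def split_prompt(tokenized_input: list, prefix_length: int, postfix_length: int):
    """Split a tokenized prompt into (cached prefix, chunk prefilled now, remainder)."""
    prefix_ids = tokenized_input[:prefix_length]
    postfix_ids = tokenized_input[prefix_length : prefix_length + postfix_length]
    remainder = tokenized_input[prefix_length + postfix_length :]
    return prefix_ids, postfix_ids, remainder

# With a 6-token prompt, 2 tokens already cached and a 3-token chunk budget:
assert split_prompt([11, 12, 13, 14, 15, 16], 2, 3) == ([11, 12], [13, 14, 15], [16])
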
@ -371,7 +390,6 @@ class FlashCausalLMBatch(Batch):
requests=pb.requests,
requests_idx_mapping=requests_idx_mapping,
input_ids=all_postfix_ids,
block_tables=block_tables,
block_tables_tensor=block_tables_tensor,
prefix_lengths=prefix_lengths,
@ -395,7 +413,6 @@ class FlashCausalLMBatch(Batch):
max_blocks=max_blocks,
speculative_ids=None,
prompt_lengths_tensor=prompt_lengths_tensor,
# These values will be set by `FlashCausalLMBatch.prepare_for_prefill`
position_ids=None,
cu_seqlen_prefill=None,
@ -431,7 +448,7 @@ class FlashCausalLMBatch(Batch):
if len(request_ids) == len(self):
return self
device = self.input_ids.device
device = self.block_tables_tensor.device
# New values after filtering
requests_idx_mapping = {}
@ -552,13 +569,13 @@ class FlashCausalLMBatch(Batch):
if self.prefilling:
# These values will be set by `FlashCausalLMBatch.prepare_for_prefill`
position_ids=None
start_slots=None
slot_indices=None
slots=None
prefix_lengths_tensor=None
postfix_lengths_tensor=None
adapter_meta=None
position_ids = None
start_slots = None
slot_indices = None
slots = None
prefix_lengths_tensor = None
postfix_lengths_tensor = None
adapter_meta = None
else:
# Index into tensors
input_ids = self.input_ids[indices]
@ -643,24 +660,24 @@ class FlashCausalLMBatch(Batch):
max_current_length = 0
for b in batches:
total_batch_size += len(b)
total_slots += len(b.slots)
max_blocks = max(max_blocks, b.max_blocks)
# If `b` is prefilling and was just filtered, `b.slots` is None
# `total_slots` is not used if any of the batches is prefilling
total_slots += len(b.slots) if not b.prefilling else 0
num_blocks += b.num_blocks
speculative_length = (
b.speculative_ids.shape[1] if b.speculative_ids is not None else 0
)
max_blocks = max(max_blocks, b.max_blocks)
max_postfix_length = max(max_postfix_length, b.max_postfix_length)
max_current_length = max(max_current_length, b.max_current_length)
max_length = max(
max_length,
max(
prefix_length
+ postfix_length
prompt_length
+ stopping_criteria.max_new_tokens
+ speculative_length
- stopping_criteria.current_tokens
for prefix_length, postfix_length, stopping_criteria in zip(
b.prefix_lengths, b.postfix_lengths, b.stopping_criterias
for prompt_length, stopping_criteria in zip(
b.prompt_lengths, b.stopping_criterias
)
),
)
@ -669,14 +686,14 @@ class FlashCausalLMBatch(Batch):
if prefilling:
input_ids = []
# These values will be set by `FlashCausalLMBatch.prepare_for_prefill`
position_ids=None
start_slots=None
slots=None
slot_indices=None
prefix_lengths_tensor=None
postfix_lengths_tensor=None
adapter_meta=None
adapter_segment_builder=None
position_ids = None
start_slots = None
slots = None
slot_indices = None
prefix_lengths_tensor = None
postfix_lengths_tensor = None
adapter_meta = None
adapter_segment_builder = None
else:
input_ids = batches[0].input_ids.new_empty(total_batch_size)
position_ids = batches[0].position_ids.new_empty(total_batch_size)
@ -746,8 +763,6 @@ class FlashCausalLMBatch(Batch):
start_index = cumulative_batch_size
end_index = cumulative_batch_size + len(batch)
slots_start_index = cumulative_slots
slots_end_index = cumulative_slots + len(batch.slots)
# Copy tensors (GPU)
top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor
@ -761,10 +776,17 @@ class FlashCausalLMBatch(Batch):
prompt_lengths_tensor[start_index:end_index] = batch.prompt_lengths_tensor
if not prefilling:
slots_start_index = cumulative_slots
slots_end_index = cumulative_slots + len(batch.slots)
input_ids[start_index:end_index] = batch.input_ids
position_ids[start_index:end_index] = batch.position_ids
slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots
postfix_lengths_tensor[start_index:end_index] = batch.postfix_lengths_tensor
slot_indices[start_index:end_index] = (
batch.slot_indices + cumulative_slots
)
postfix_lengths_tensor[start_index:end_index] = (
batch.postfix_lengths_tensor
)
slots[slots_start_index:slots_end_index] = batch.slots
# Copy over adapter indices
@ -779,11 +801,17 @@ class FlashCausalLMBatch(Batch):
cumulative_adapter_indices_size = adapter_end_index
adapter_set.update(batch.adapter_meta.adapter_set)
adapter_segment_builder.concat(
batch.adapter_meta.adapter_segments, batch.adapter_meta.segment_indices
batch.adapter_meta.adapter_segments,
batch.adapter_meta.segment_indices,
)
prefix_lengths_tensor[start_index:end_index] = (
batch.prefix_lengths_tensor
)
prefix_lengths_tensor[start_index:end_index] = batch.prefix_lengths_tensor
start_slots.append(batch.start_slots + cumulative_slots)
# Update
cumulative_slots += len(batch.slots)
else:
if isinstance(batch.input_ids, torch.Tensor):
batch.input_ids = batch.input_ids.view(-1, 1).tolist()
@ -810,7 +838,6 @@ class FlashCausalLMBatch(Batch):
# Update
cumulative_batch_size += len(batch)
cumulative_slots += len(batch.slots)
if start_slots is not None:
start_slots = torch.concat(start_slots)
@ -915,7 +942,7 @@ class FlashCausalLMBatch(Batch):
postfix_length,
prompt_length,
request_prefilling,
blocks
blocks,
) in enumerate(
zip(
self.requests,
@ -923,7 +950,7 @@ class FlashCausalLMBatch(Batch):
self.postfix_lengths,
self.prompt_lengths,
self.prefilling_mask,
self.block_tables
self.block_tables,
)
):
next_chunk_length = postfix_length
@ -967,9 +994,7 @@ class FlashCausalLMBatch(Batch):
no_prefill_logprobs = no_prefill_logprobs and not prefill_logprobs
if prefill_logprobs:
prefill_head_indices.append(
request_position_ids + cumulative_length
)
prefill_head_indices.append(request_position_ids + cumulative_length)
prefill_next_token_indices.append(
prefill_out_cumulative_length + postfix_length - 1
)
@ -988,7 +1013,6 @@ class FlashCausalLMBatch(Batch):
prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
prefill_out_cumulative_length += 1
start_slots.append(cumulative_slot_tokens)
slots.extend(request_slots)
slot_indices.append(request_slot_indices)
@ -998,9 +1022,7 @@ class FlashCausalLMBatch(Batch):
ADAPTER_TO_INDEX = get_adapter_to_index()
adapter_index = ADAPTER_TO_INDEX.get(r.adapter_id, 0)
adapter_indices_list.append(
torch.full((next_chunk_length,), adapter_index)
)
adapter_indices_list.append(torch.full((next_chunk_length,), adapter_index))
adapter_set.add(adapter_index)
# Update
@ -1240,6 +1262,7 @@ class FlashCausalLM(Model):
rank=rank,
world_size=world_size,
sliding_window=config.sliding_window,
support_chunking=True,
)
@property
@ -1764,29 +1787,43 @@ class FlashCausalLM(Model):
finished_prefilling = True
next_chunk_lengths = []
if prefill:
next_prefilling_mask = []
# Budget in tokens for the next batch
# We remove next input ids to always have enough space for at least a single decode
# for the remaining requests
batch_budget = TOKEN_BUDGET - len(batch)
for prefix_length, postfix_length, prompt_length in zip(
batch.prefix_lengths, batch.postfix_lengths, batch.prompt_lengths
):
remaining_prefill_tokens = max(
prompt_length - prefix_length - postfix_length, 0
)
if remaining_prefill_tokens > 0:
next_chunk_length = max(
min(remaining_prefill_tokens, batch_budget), 1
if get_support_chunking():
next_prefilling_mask = []
# Budget in tokens for the next batch
# We remove len(batch) to always have enough space for at least a single decode
# for the remaining requests
batch_budget = get_max_prefill_tokens() - len(batch)
# We reverse to prioritize older requests
# zip() is not reversible so reverse the underlying lists instead
for prefix_length, postfix_length, prompt_length in zip(
reversed(batch.prefix_lengths),
reversed(batch.postfix_lengths),
reversed(batch.prompt_lengths),
):
remaining_prefill_tokens = max(
prompt_length - prefix_length - postfix_length, 0
)
batch_budget -= next_chunk_length
finished_prefilling = False
next_prefilling_mask.append(True)
else:
# Since speculation will be turned off, this is always true
next_chunk_length = 1
next_prefilling_mask.append(False)
next_chunk_lengths.append(next_chunk_length)
if remaining_prefill_tokens > 0:
next_chunk_length = max(
min(remaining_prefill_tokens, batch_budget), 1
)
batch_budget -= next_chunk_length
finished_prefilling = False
next_prefilling_mask.append(True)
else:
# Since speculation will be turned off, this is always true
next_chunk_length = 1
next_prefilling_mask.append(False)
next_chunk_lengths.append(next_chunk_length)
# Reverse back the obtained values
next_chunk_lengths.reverse()
next_prefilling_mask.reverse()
else:
# The model does not support chunking
# We know we only do a single prefill
finished_prefilling = True
next_prefilling_mask = [False] * len(batch)
batch.prefilling = not finished_prefilling
batch.prefilling_mask = next_prefilling_mask
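The loop above spreads the warmup-provided prefill budget over the requests that still have prompt tokens left, while requests that finished prefilling get a single decode slot. A compact sketch of the same logic as a free function (hypothetical; the real code lives inside `generate_token`):

from typing import List, Tuple

def schedule_next_chunks(
    prefix_lengths: List[int],
    postfix_lengths: List[int],
    prompt_lengths: List[int],
    max_prefill_tokens: int,
) -> Tuple[List[int], List[bool]]:
    """Return (next_chunk_lengths, next_prefilling_mask) for every request."""
    # Keep room for at least one decode token per request.
    batch_budget = max_prefill_tokens - len(prompt_lengths)
    next_chunk_lengths: List[int] = []
    next_prefilling_mask: List[bool] = []
    # Iterate in reverse, as in the diff, and undo the reversal at the end.
    for prefix, postfix, prompt in zip(
        reversed(prefix_lengths), reversed(postfix_lengths), reversed(prompt_lengths)
    ):
        remaining = max(prompt - prefix - postfix, 0)
        if remaining > 0:
            chunk = max(min(remaining, batch_budget), 1)
            batch_budget -= chunk
            next_prefilling_mask.append(True)
        else:
            chunk = 1  # the request moves on to decoding
            next_prefilling_mask.append(False)
        next_chunk_lengths.append(chunk)
    next_chunk_lengths.reverse()
    next_prefilling_mask.reverse()
    return next_chunk_lengths, next_prefilling_mask
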
@ -2179,7 +2216,9 @@ class FlashCausalLM(Model):
# have more than one new token per request with speculative decoding
for next_token_id in _next_token_ids:
batch.next_token_chooser = (
batch.next_token_chooser.advance_grammar_single(i, next_token_id)
batch.next_token_chooser.advance_grammar_single(
i, next_token_id
)
)
# Update values

View File

@ -18,7 +18,7 @@ if PREFIX_CACHING and ATTENTION not in {"flashinfer", "flashdecoding"}:
raise RuntimeError("Prefix caching is only supported with flashinfer")
MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
TGI_WIGGLE_ROOM = float(os.getenv("TGI_WIGGLE_ROOM", "0.95"))
TGI_WIGGLE_ROOM = float(os.getenv("TGI_WIGGLE_ROOM", "0.90"))
assert TGI_WIGGLE_ROOM > 0
assert TGI_WIGGLE_ROOM < 1

View File

@ -83,6 +83,7 @@ class IdeficsCausalLMBatch(Batch):
request_ids=[r.id for r in self.requests],
size=len(self),
max_tokens=self.max_tokens,
current_tokens=len(self),
)
@classmethod

View File

@ -116,6 +116,7 @@ class MambaBatch(Batch):
request_ids=[r.id for r in self.requests],
size=len(self),
max_tokens=self.max_tokens,
current_tokens=len(self),
)
@classmethod

View File

@ -5,8 +5,11 @@ from abc import ABC, abstractmethod
from typing import List, Tuple, Optional, TypeVar, Type, Dict
from collections import defaultdict
from transformers import PreTrainedTokenizerBase
from loguru import logger
from text_generation_server.models.types import Batch, Generation
from text_generation_server.utils.log import log_master
from text_generation_server.utils.prefill_chunking import set_support_chunking
from text_generation_server.utils.speculate import get_speculate
from text_generation_server.pb.generate_pb2 import InfoResponse
from text_generation_server.adapters.weights import LayerAdapterWeights
@ -31,6 +34,7 @@ class Model(ABC):
sliding_window: Optional[int] = None,
speculate: Optional[int] = None,
adapter_id: str = BASE_MODEL_ADAPTER_ID,
support_chunking: bool = False,
):
self.model_id = model_id
self.model = model.eval()
@ -60,6 +64,17 @@ class Model(ABC):
speculate = get_speculate()
self.speculate = speculate
if speculate != 0 and support_chunking:
log_master(
logger.warning,
"Prefill chunking does not support speculation yet. "
"Prefill chunking will be turned off",
)
support_chunking = False
self.support_chunking = support_chunking
set_support_chunking(support_chunking)
self.has_position_ids = (
inspect.signature(model.forward).parameters.get("position_ids", None)
is not None
@ -78,6 +93,7 @@ class Model(ABC):
device_type=self.device.type,
window_size=self.sliding_window,
speculate=self.speculate,
support_chunking=self.support_chunking,
)
@property

View File

@ -80,6 +80,7 @@ class Seq2SeqLMBatch(Batch):
request_ids=[r.id for r in self.requests],
size=len(self),
max_tokens=self.max_tokens,
current_tokens=len(self),
)
@classmethod

View File

@ -357,7 +357,6 @@ class VlmCausalLM(FlashCausalLM):
else:
cuda_graph = None
if cu_seqlen_prefill is not None or cuda_graph is None:
input_lengths = postfix_lengths + prefix_lengths_tensor
if PREFIX_CACHING:
block_tables = block_tables_to_ragged(
block_tables=block_tables,
@ -424,7 +423,7 @@ class VlmCausalLM(FlashCausalLM):
cuda_graph["postfix_lengths"][: postfix_lengths.shape[0]] = postfix_lengths
cuda_graph["prefix_lengths"].zero_()
cuda_graph["prefix_lengths"][
: prefix_lengths_tensor.shape[0]
: prefix_lengths_tensor.shape[0]
] = prefix_lengths_tensor
with self._forward_context(

View File

@ -15,6 +15,7 @@ from text_generation_server.cache import Cache
from text_generation_server.interceptor import ExceptionInterceptor
from text_generation_server.models import Model, get_model_with_lora_adapters
from text_generation_server.utils.adapter import AdapterInfo
from text_generation_server.utils.prefill_chunking import set_max_prefill_tokens
try:
from text_generation_server.models.pali_gemma import PaliGemmaBatch
@ -96,6 +97,8 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())
async def Warmup(self, request, context):
set_max_prefill_tokens(request.max_prefill_tokens)
if self.quantize in {"exl2", "gptq"}:
try:
# When using GPTQ, Exllama kernels need some global kernels

View File

@ -0,0 +1,24 @@
from typing import Optional
SUPPORT_CHUNKING: Optional[bool] = None
MAX_PREFILL_TOKENS: Optional[int] = None
def set_support_chunking(support_chunking: bool):
global SUPPORT_CHUNKING
SUPPORT_CHUNKING = support_chunking
def get_support_chunking() -> bool:
global SUPPORT_CHUNKING
return SUPPORT_CHUNKING
def set_max_prefill_tokens(max_prefill_tokens: int):
global MAX_PREFILL_TOKENS
MAX_PREFILL_TOKENS = max_prefill_tokens
def get_max_prefill_tokens() -> int:
global MAX_PREFILL_TOKENS
return MAX_PREFILL_TOKENS
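
A short usage sketch of this module, mirroring how `Warmup` and the `Model` constructor wire it elsewhere in this diff (values are illustrative):

from text_generation_server.utils.prefill_chunking import (
    set_max_prefill_tokens,
    get_max_prefill_tokens,
    set_support_chunking,
    get_support_chunking,
)

# During Warmup, the router-provided budget is stored process-wide...
set_max_prefill_tokens(4096)
# ...and the model declares whether it can split prefills.
set_support_chunking(True)

# Later, the batching code reads both values back.
assert get_max_prefill_tokens() == 4096
assert get_support_chunking() is True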