feat: move allocation logic to rust (#1835)

Close #2007
OlivierDehaene authored on 2024-06-05 12:18:38 +02:00, committed by yuanwu
parent cdd120ac02
commit 20df9234a9
24 changed files with 499 additions and 845 deletions

View File

@@ -26,7 +26,12 @@ incremental = true
inherits = "release"
debug = 1
incremental = true
+panic = "abort"
+
+[profile.release-opt]
+inherits = "release"
+debug = 0
+incremental = false
lto = "fat"
opt-level = 3
codegen-units = 1
-panic = "abort"

View File

@@ -25,7 +25,7 @@ RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
rm -f $PROTOC_ZIP
COPY --from=planner /usr/src/recipe.json recipe.json
-RUN cargo chef cook --release --recipe-path recipe.json
+RUN cargo chef cook --profile release-opt --recipe-path recipe.json
COPY Cargo.toml Cargo.toml
COPY rust-toolchain.toml rust-toolchain.toml
@@ -33,7 +33,7 @@ COPY proto proto
COPY benchmark benchmark
COPY router router
COPY launcher launcher
-RUN cargo build --release
+RUN cargo build --profile release-opt
# Text Generation Inference base image for RoCm
FROM rocm/dev-ubuntu-22.04:6.1.1_hip_update as base
@@ -193,11 +193,11 @@ RUN cd server && \
pip install ".[accelerate, peft, outlines]" --no-cache-dir
# Install benchmarker
-COPY --from=builder /usr/src/target/release/text-generation-benchmark /usr/local/bin/text-generation-benchmark
+COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
# Install router
-COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bin/text-generation-router
+COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/local/bin/text-generation-router
# Install launcher
-COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher
+COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher
# AWS Sagemaker compatible image
FROM base as sagemaker

View File

@@ -24,7 +24,7 @@ RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
rm -f $PROTOC_ZIP
COPY --from=planner /usr/src/recipe.json recipe.json
-RUN cargo chef cook --release --recipe-path recipe.json
+RUN cargo chef cook --profile release-opt --recipe-path recipe.json
COPY Cargo.toml Cargo.toml
COPY rust-toolchain.toml rust-toolchain.toml
@@ -32,7 +32,7 @@ COPY proto proto
COPY benchmark benchmark
COPY router router
COPY launcher launcher
-RUN cargo build --release
+RUN cargo build --profile release-opt
# Text Generation Inference base image for Intel
@@ -78,11 +78,11 @@ ENV PATH=/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/intel/oneapi/mp
ENV CCL_ZE_IPC_EXCHANGE=sockets
# Install benchmarker
-COPY --from=builder /usr/src/target/release/text-generation-benchmark /usr/local/bin/text-generation-benchmark
+COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
# Install router
-COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bin/text-generation-router
+COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/local/bin/text-generation-router
# Install launcher
-COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher
+COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher
# Final image
FROM base

View File

@@ -155,6 +155,8 @@ async fn prefill(
ignore_eos_token: true, // Will not stop even if a eos token is generated
}),
top_n_tokens: top_n_tokens.unwrap_or(0),
+blocks: vec![],
+slots: vec![],
})
.collect();
@@ -163,6 +165,7 @@ async fn prefill(
requests,
size: batch_size,
max_tokens: batch_size * (sequence_length + decode_length),
+max_blocks: 0,
};
// Run prefill

View File

@@ -130,6 +130,10 @@ message Request {
bool prefill_logprobs = 6;
/// Return most likely n tokens
uint32 top_n_tokens = 7;
+/// Paged attention blocks
+repeated uint32 blocks = 9;
+/// Paged attention slots
+repeated uint32 slots = 10;
}
message Batch {
@@ -141,6 +145,8 @@ message Batch {
uint32 size = 3;
/// Maximum number of tokens this batch will grow to
uint32 max_tokens = 4;
+/// Maximum number of Paged Attention blocks
+uint32 max_blocks = 5;
}
message CachedBatch {

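The new blocks and slots fields carry the paged-attention allocation the router now computes: each block contributes block_size consecutive slot ids. A minimal sketch of that relationship, assuming a hypothetical helper name and the block size of 16 used elsewhere in this commit:

// Hypothetical helper (not part of the diff): expand paged-attention block ids
// into the slot ids they cover, for a given block size.
fn slots_for_blocks(blocks: &[u32], block_size: u32) -> Vec<u32> {
    blocks
        .iter()
        .copied()
        .flat_map(|block_id| (block_id * block_size)..((block_id + 1) * block_size))
        .collect()
}

For example, slots_for_blocks(&[0], 16) yields slots 0 through 15, which is exactly the shape of the health-check request further down in this diff.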
View File

@@ -153,6 +153,9 @@ impl Client {
}),
// We truncate the input on the server side to be sure that it has the correct size
truncate,
+// Blocks and slots will be set on the server side if we use paged attention
+blocks: vec![],
+slots: vec![],
// Set sampling parameters to also take these ops into account in the max memory
parameters: Some(NextTokenChooserParameters {
temperature: 0.9,
@@ -187,7 +190,8 @@ impl Client {
id: 0,
size: requests.len() as u32,
requests,
-max_tokens: 0,
+max_tokens: max_input_length,
+max_blocks: 0,
};
let request = tonic::Request::new(WarmupRequest {

View File

@@ -241,12 +241,16 @@ impl Health for ShardedClient {
ignore_eos_token: false,
}),
top_n_tokens: 0,
+// Block 0 is reserved for health checks
+blocks: vec![0],
+slots: (0..16).collect(),
};
let batch = Batch {
id: u64::MAX,
requests: vec![liveness_request],
size: 1,
max_tokens: 2,
+max_blocks: 1,
};
self.clone().prefill(batch).await?;
Ok(())

View File

@@ -0,0 +1,136 @@
use std::cmp::min;
use tokio::sync::{mpsc, oneshot};
#[derive(Debug, Clone)]
pub(crate) struct BlockAllocation {
pub blocks: Vec<u32>,
pub slots: Vec<u32>,
block_allocator: BlockAllocator,
}
impl Drop for BlockAllocation {
fn drop(&mut self) {
self.block_allocator.free(self.blocks.clone())
}
}
#[derive(Debug, Clone)]
pub(crate) struct BlockAllocator {
/// Channel to communicate with the background task
block_allocator: mpsc::UnboundedSender<BlockAllocatorCommand>,
}
impl BlockAllocator {
pub(crate) fn new(
max_batch_total_tokens: u32,
block_size: u32,
window_size: Option<u32>,
) -> Self {
// Create channel
let (sender, receiver) = mpsc::unbounded_channel();
// Launch background queue task
tokio::spawn(block_allocator_task(
max_batch_total_tokens / block_size,
block_size,
window_size,
receiver,
));
Self {
block_allocator: sender,
}
}
pub(crate) async fn allocate(&self, tokens: u32) -> Option<BlockAllocation> {
let (response_sender, response_receiver) = oneshot::channel();
self.block_allocator
.send(BlockAllocatorCommand::Allocate {
tokens,
response_sender,
})
.unwrap();
response_receiver
.await
.unwrap()
.map(|(blocks, slots)| BlockAllocation {
blocks,
slots,
block_allocator: self.clone(),
})
}
pub(crate) fn free(&self, blocks: Vec<u32>) {
self.block_allocator
.send(BlockAllocatorCommand::Free { blocks })
.unwrap();
}
}
async fn block_allocator_task(
blocks: u32,
block_size: u32,
window_size: Option<u32>,
mut receiver: mpsc::UnboundedReceiver<BlockAllocatorCommand>,
) {
// Block 0 is reserved for health checks
let mut free_blocks: Vec<u32> = (1..blocks).collect();
while let Some(cmd) = receiver.recv().await {
match cmd {
BlockAllocatorCommand::Free { blocks } => free_blocks.extend(blocks),
BlockAllocatorCommand::Allocate {
tokens,
response_sender,
} => {
// Apply window size
let (required_blocks, repeats) = {
let (tokens, repeats) = match window_size {
None => (tokens, 1),
Some(window_size) => {
let repeats = (tokens + window_size - 1) / window_size;
let tokens = min(tokens, window_size);
(tokens, repeats as usize)
}
};
// Pad to a multiple of block size
let required_blocks = (tokens + block_size - 1) / block_size;
(required_blocks, repeats)
};
let tokens = tokens as usize;
let allocation = if required_blocks > free_blocks.len() as u32 {
None
} else {
let blocks =
free_blocks.split_off(free_blocks.len() - required_blocks as usize);
let mut slots = Vec::with_capacity(
(required_blocks * block_size * repeats as u32) as usize,
);
'slots: for block_id in blocks.repeat(repeats).iter() {
for s in (block_id * block_size)..((block_id + 1) * block_size) {
slots.push(s);
if slots.len() == tokens {
break 'slots;
}
}
}
Some((blocks, slots))
};
response_sender.send(allocation).unwrap();
}
}
}
}
#[derive(Debug)]
enum BlockAllocatorCommand {
Free {
blocks: Vec<u32>,
},
Allocate {
tokens: u32,
response_sender: oneshot::Sender<Option<(Vec<u32>, Vec<u32>)>>,
},
}
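For orientation, a minimal usage sketch of the allocator above (illustrative only, not part of the commit): with no sliding window configured, allocate(tokens) hands back ceil(tokens / block_size) blocks plus one slot id per token, and the blocks return to the free list as soon as the BlockAllocation is dropped.

// Illustrative sketch: reserve room for 64 tokens with the 16-token block size
// used in this commit, assuming the allocator still has enough free blocks.
async fn allocation_example(block_allocator: &BlockAllocator) {
    if let Some(allocation) = block_allocator.allocate(64).await {
        // 64 tokens with block_size = 16 -> 4 blocks and 64 slots
        assert_eq!(allocation.blocks.len(), 4);
        assert_eq!(allocation.slots.len(), 64);
    } // `allocation` is dropped here and its blocks go back to the free list
}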

View File

@@ -1,3 +1,4 @@
+mod block_allocator;
mod queue;
mod scheduler;

View File

@@ -1,17 +1,20 @@
-use crate::infer::{InferError, InferStreamResponse};
+use crate::infer::v3::block_allocator::{BlockAllocation, BlockAllocator};
+use crate::infer::InferError;
+use crate::infer::InferStreamResponse;
use crate::validation::{
ValidGenerateRequest, ValidGrammar, ValidParameters, ValidStoppingParameters,
};
use nohash_hasher::{BuildNoHashHasher, IntMap};
-use std::cmp::min;
+use std::cmp::{max, min};
use std::collections::VecDeque;
use text_generation_client::v3::{
Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
};
-use text_generation_client::{ChunksToString, Input};
+use text_generation_client::ChunksToString;
+use text_generation_client::Input;
use tokio::sync::{mpsc, oneshot};
use tokio::time::Instant;
-use tracing::{info_span, instrument, Span};
+use tracing::{info_span, instrument, Instrument, Span};
/// Queue entry
#[derive(Debug)]
@@ -28,6 +31,8 @@ pub(crate) struct Entry {
pub queue_time: Instant,
/// Instant when this entry was added to a batch
pub batch_time: Option<Instant>,
+/// Block Allocation
+pub block_allocation: Option<BlockAllocation>,
}
/// Request Queue
@@ -43,6 +48,7 @@ impl Queue {
block_size: u32,
window_size: Option<u32>,
speculate: u32,
+max_batch_total_tokens: u32,
) -> Self {
// Create channel
let (queue_sender, queue_receiver) = mpsc::unbounded_channel();
@@ -53,12 +59,14 @@ impl Queue {
block_size,
window_size,
speculate,
+max_batch_total_tokens,
queue_receiver,
));
Self { queue_sender }
}
+/// Append an entry to the queue
#[instrument(skip_all)]
pub(crate) fn append(&self, entry: Entry) {
// Send append command to the background task managing the state
@@ -103,9 +111,16 @@ async fn queue_task(
block_size: u32,
window_size: Option<u32>,
speculate: u32,
+max_batch_total_tokens: u32,
mut receiver: mpsc::UnboundedReceiver<QueueCommand>,
) {
-let mut state = State::new(requires_padding, block_size, window_size, speculate);
+let mut state = State::new(
+requires_padding,
+block_size,
+window_size,
+speculate,
+max_batch_total_tokens,
+);
while let Some(cmd) = receiver.recv().await {
match cmd {
@@ -120,12 +135,14 @@ async fn queue_task(
token_budget,
response_sender,
span,
-} => span.in_scope(|| {
-let next_batch =
-state.next_batch(min_size, max_size, prefill_token_budget, token_budget);
+} => {
+let next_batch = state
+.next_batch(min_size, max_size, prefill_token_budget, token_budget)
+.instrument(span)
+.await;
response_sender.send(next_batch).unwrap();
metrics::gauge!("tgi_queue_size", state.entries.len() as f64);
-}),
+}
}
}
}
@@ -142,9 +159,6 @@ struct State {
/// Id of the next batch
next_batch_id: u64,
-/// Whether the model is using padding
-requires_padding: bool,
/// Paged Attention block size
block_size: u32,
@@ -153,6 +167,9 @@ struct State {
/// Speculation amount
speculate: u32,
+/// Paged Attention Block Allocation
+block_allocator: Option<BlockAllocator>,
}
impl State {
@@ -161,15 +178,19 @@ impl State {
block_size: u32,
window_size: Option<u32>,
speculate: u32,
+max_batch_total_tokens: u32,
) -> Self {
+let block_allocator = (!requires_padding)
+.then(|| BlockAllocator::new(max_batch_total_tokens, block_size, window_size));
Self {
entries: VecDeque::with_capacity(128),
next_id: 0,
next_batch_id: 0,
-requires_padding,
block_size,
window_size,
speculate,
+block_allocator,
}
}
@@ -185,7 +206,7 @@ impl State {
}
// Get the next batch
-fn next_batch(
+async fn next_batch(
&mut self,
min_size: Option<usize>,
max_size: Option<usize>,
@@ -220,9 +241,10 @@ impl State {
let mut max_input_length = 0;
let mut prefill_tokens: u32 = 0;
let mut decode_tokens: u32 = 0;
+let mut max_blocks = 0;
// Pop entries starting from the front of the queue
-while let Some((id, mut entry)) = self.entries.pop_front() {
+'entry_loop: while let Some((id, mut entry)) = self.entries.pop_front() {
// Filter entries where the response receiver was dropped (== entries where the request
// was dropped by the client)
if entry.response_tx.is_closed() {
@@ -231,43 +253,67 @@ impl State {
continue;
}
-if self.requires_padding {
-// We pad to max input length in the Python shards
-// We need to take these padding tokens into the equation
-max_input_length = max_input_length.max(entry.request.input_length);
-prefill_tokens = (batch_requests.len() + 1) as u32 * max_input_length
-} else {
-// pad to block size
-prefill_tokens += ((entry.request.input_length + self.block_size - 1)
-/ self.block_size)
-* self.block_size;
-}
-if self.requires_padding {
-decode_tokens += entry.request.stopping_parameters.max_new_tokens;
-} else {
-let max_new_tokens = match self.window_size {
-None => entry.request.stopping_parameters.max_new_tokens,
-Some(window_size) => min(
-window_size.saturating_sub(entry.request.input_length),
-entry.request.stopping_parameters.max_new_tokens,
-),
-};
-// pad to block size
-decode_tokens +=
-((max_new_tokens + self.block_size - 1) / self.block_size) * self.block_size;
-}
+let block_allocation = match &self.block_allocator {
+None => {
+// We pad to max input length in the Python shards
+// We need to take these padding tokens into the equation
+max_input_length = max_input_length.max(entry.request.input_length);
+prefill_tokens = (batch_requests.len() + 1) as u32 * max_input_length;
+decode_tokens += entry.request.stopping_parameters.max_new_tokens;
+let total_tokens = prefill_tokens + decode_tokens + self.speculate;
+if prefill_tokens > prefill_token_budget || total_tokens > token_budget {
+// Entry is over budget
+// Add it back to the front
+tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
+self.entries.push_front((id, entry));
+break 'entry_loop;
+}
+None
+}
+Some(block_allocator) => {
+prefill_tokens += entry.request.input_length;
+let max_new_tokens = match self.window_size {
+None => entry.request.stopping_parameters.max_new_tokens,
+Some(window_size) => min(
+window_size.saturating_sub(entry.request.input_length),
+entry.request.stopping_parameters.max_new_tokens,
+),
+};
+decode_tokens += max_new_tokens;
if prefill_tokens > prefill_token_budget
|| (prefill_tokens + decode_tokens + self.speculate) > token_budget
{
// Entry is over budget
// Add it back to the front
tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
self.entries.push_front((id, entry));
break;
}
+let tokens = entry.request.input_length
++ entry.request.stopping_parameters.max_new_tokens
++ self.speculate
+- 1;
+match block_allocator.allocate(tokens).await {
+None => {
+// Entry is over budget
+// Add it back to the front
+tracing::debug!("Over budget: not enough free blocks");
+self.entries.push_front((id, entry));
+break 'entry_loop;
+}
+Some(block_allocation) => {
+tracing::debug!("Allocation: {block_allocation:?}");
+max_blocks = max(max_blocks, block_allocation.blocks.len() as u32);
+Some(block_allocation)
+}
+}
+}
+};
tracing::debug!("Accepting entry");
// Create a new span to link the batch back to this entry
@@ -278,13 +324,23 @@ impl State {
// Update entry
entry.temp_span = Some(entry_batch_span);
+let (blocks, slots) = match &block_allocation {
+None => (Vec::new(), Vec::new()),
+Some(block_allocation) => (
+block_allocation.blocks.clone(),
+block_allocation.slots.clone(),
+),
+};
+entry.block_allocation = block_allocation;
batch_requests.push(Request {
id,
prefill_logprobs: entry.request.decoder_input_details,
-inputs: entry.request.inputs.chunks_to_string(),
input_chunks: Some(Input {
chunks: entry.request.inputs.clone(),
}),
+inputs: entry.request.inputs.chunks_to_string(),
truncate: entry.request.truncate,
parameters: Some(NextTokenChooserParameters::from(
entry.request.parameters.clone(),
@@ -293,6 +349,8 @@ impl State {
entry.request.stopping_parameters.clone(),
)),
top_n_tokens: entry.request.top_n_tokens,
+blocks,
+slots,
});
// Set batch_time
entry.batch_time = Some(Instant::now());
@@ -335,6 +393,7 @@ impl State {
requests: batch_requests,
size,
max_tokens: (prefill_tokens + decode_tokens),
+max_blocks,
};
// Increment batch id
self.next_batch_id += 1;
@@ -438,13 +497,14 @@ mod tests {
temp_span: None,
queue_time: Instant::now(),
batch_time: None,
+block_allocation: None,
};
(entry, receiver_tx)
}
-#[test]
-fn test_append() {
-let mut state = State::new(false, 1, None, 0);
+#[tokio::test]
+async fn test_append() {
+let mut state = State::new(false, 1, None, 0, 16);
let (entry, _guard) = default_entry();
assert_eq!(state.next_id, 0);
@@ -458,23 +518,23 @@ mod tests {
assert_eq!(id, 0);
}
-#[test]
-fn test_next_batch_empty() {
-let mut state = State::new(false, 1, None, 0);
+#[tokio::test]
+async fn test_next_batch_empty() {
+let mut state = State::new(false, 1, None, 0, 16);
-assert!(state.next_batch(None, None, 1, 1).is_none());
-assert!(state.next_batch(Some(1), None, 1, 1).is_none());
+assert!(state.next_batch(None, None, 1, 1).await.is_none());
+assert!(state.next_batch(Some(1), None, 1, 1).await.is_none());
}
-#[test]
-fn test_next_batch_min_size() {
-let mut state = State::new(false, 1, None, 0);
+#[tokio::test]
+async fn test_next_batch_min_size() {
+let mut state = State::new(false, 1, None, 0, 16);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
state.append(entry1);
state.append(entry2);
-let (entries, batch, _) = state.next_batch(None, None, 2, 2).unwrap();
+let (entries, batch, _) = state.next_batch(None, None, 2, 2).await.unwrap();
assert_eq!(entries.len(), 2);
assert!(entries.contains_key(&0));
assert!(entries.contains_key(&1));
@@ -490,7 +550,7 @@ mod tests {
let (entry3, _guard3) = default_entry();
state.append(entry3);
-assert!(state.next_batch(Some(2), None, 2, 2).is_none());
+assert!(state.next_batch(Some(2), None, 2, 2).await.is_none());
assert_eq!(state.next_id, 3);
assert_eq!(state.entries.len(), 1);
@@ -498,15 +558,15 @@ mod tests {
assert_eq!(id, 2);
}
-#[test]
-fn test_next_batch_max_size() {
-let mut state = State::new(false, 1, None, 0);
+#[tokio::test]
+async fn test_next_batch_max_size() {
+let mut state = State::new(false, 1, None, 0, 16);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
state.append(entry1);
state.append(entry2);
-let (entries, batch, _) = state.next_batch(None, Some(1), 2, 2).unwrap();
+let (entries, batch, _) = state.next_batch(None, Some(1), 2, 2).await.unwrap();
assert_eq!(entries.len(), 1);
assert!(entries.contains_key(&0));
assert!(entries.get(&0).unwrap().batch_time.is_some());
@@ -518,15 +578,15 @@ mod tests {
assert_eq!(state.next_batch_id, 1);
}
-#[test]
-fn test_next_batch_token_budget() {
-let mut state = State::new(false, 1, None, 0);
+#[tokio::test]
+async fn test_next_batch_token_budget() {
+let mut state = State::new(false, 1, None, 0, 2);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
state.append(entry1);
state.append(entry2);
-let (entries, batch, _) = state.next_batch(None, None, 1, 1).unwrap();
+let (entries, batch, _) = state.next_batch(None, None, 1, 1).await.unwrap();
assert_eq!(entries.len(), 1);
assert!(entries.contains_key(&0));
assert_eq!(batch.id, 0);
@@ -539,7 +599,7 @@ mod tests {
let (entry3, _guard3) = default_entry();
state.append(entry3);
-let (entries, batch, _) = state.next_batch(None, None, 3, 3).unwrap();
+let (entries, batch, _) = state.next_batch(None, None, 3, 3).await.unwrap();
assert_eq!(entries.len(), 2);
assert!(entries.contains_key(&1));
assert!(entries.contains_key(&2));
@@ -553,14 +613,14 @@ mod tests {
#[tokio::test]
async fn test_queue_append() {
-let queue = Queue::new(false, 1, None, 0);
+let queue = Queue::new(false, 1, None, 0, 16);
let (entry, _guard) = default_entry();
queue.append(entry);
}
#[tokio::test]
async fn test_queue_next_batch_empty() {
-let queue = Queue::new(false, 1, None, 0);
+let queue = Queue::new(false, 1, None, 0, 16);
assert!(queue.next_batch(None, None, 1, 1).await.is_none());
assert!(queue.next_batch(Some(1), None, 1, 1).await.is_none());
@@ -568,7 +628,7 @@ mod tests {
#[tokio::test]
async fn test_queue_next_batch_min_size() {
-let queue = Queue::new(false, 1, None, 0);
+let queue = Queue::new(false, 1, None, 0, 16);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
queue.append(entry1);
@@ -601,7 +661,7 @@ mod tests {
#[tokio::test]
async fn test_queue_next_batch_max_size() {
-let queue = Queue::new(false, 1, None, 0);
+let queue = Queue::new(false, 1, None, 0, 16);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
queue.append(entry1);
@@ -617,7 +677,7 @@ mod tests {
#[tokio::test]
async fn test_queue_next_batch_token_budget() {
-let queue = Queue::new(false, 1, None, 0);
+let queue = Queue::new(false, 1, None, 0, 16);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
queue.append(entry1);
@@ -642,7 +702,7 @@ mod tests {
#[tokio::test]
async fn test_queue_next_batch_token_speculate() {
-let queue = Queue::new(false, 1, None, 2);
+let queue = Queue::new(false, 1, None, 2, 16);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
queue.append(entry1);
@@ -661,7 +721,7 @@ mod tests {
#[tokio::test]
async fn test_queue_next_batch_dropped_receiver() {
-let queue = Queue::new(false, 1, None, 0);
+let queue = Queue::new(false, 1, None, 0, 16);
let (entry, _) = default_entry();
queue.append(entry);

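A worked example of the per-entry accounting in next_batch above (illustrative numbers, not taken from the diff), using the block size of 16 that SchedulerV3 passes to Queue::new:

fn main() {
    let input_length: u32 = 10;
    let max_new_tokens: u32 = 20;
    let speculate: u32 = 2;
    let block_size: u32 = 16;
    // Tokens reserved for the request before calling allocate()
    let tokens = input_length + max_new_tokens + speculate - 1; // 31
    // Blocks the allocator will carve out: ceil(31 / 16)
    let required_blocks = (tokens + block_size - 1) / block_size; // 2
    assert_eq!((tokens, required_blocks), (31, 2));
}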
View File

@@ -39,7 +39,13 @@ impl SchedulerV3 {
speculate: u32,
generation_health: Arc<AtomicBool>,
) -> Self {
-let queue = Queue::new(requires_padding, 16, window_size, speculate);
+let queue = Queue::new(
+requires_padding,
+16,
+window_size,
+speculate,
+max_batch_total_tokens,
+);
let batching_task_notifier = Arc::new(Notify::new());
// Spawn batching background task that contains all the inference logic
@@ -81,6 +87,7 @@ impl Scheduler for SchedulerV3 {
temp_span: None,
queue_time: Instant::now(),
batch_time: None,
+block_allocation: None,
});
// Notify the background task that we have a new entry in the queue that needs

View File

@@ -1,140 +0,0 @@
import math
import torch
from typing import Optional, List, Tuple
from text_generation_server.utils.import_utils import SYSTEM
BLOCK_SIZE: int = 16
# Will be set in warmup
CACHE_MANAGER: Optional["CacheManager"] = None
class CacheManager:
def __init__(
self,
num_blocks: int,
num_layers: int,
num_heads: int,
head_size: int,
repeat_slots: bool,
dtype: torch.dtype,
device: torch.device,
):
self.block_size = BLOCK_SIZE
self.num_blocks = num_blocks
self.repeat_slots = repeat_slots
element_size = torch.tensor([], dtype=dtype).element_size()
if SYSTEM == "xpu":
x = 1
else:
x = self.block_size // element_size
self.kv_cache = [
(
torch.empty(
(num_blocks, num_heads, head_size // x, self.block_size, x),
dtype=dtype,
device=device,
),
torch.empty(
(num_blocks, num_heads, head_size, self.block_size),
dtype=dtype,
device=device,
),
)
for _ in range(num_layers)
]
self.free_block_mask = torch.ones(num_blocks, dtype=torch.int32, device="cpu")
self.slots = torch.arange(
0, num_blocks * self.block_size, dtype=torch.int64
).view(num_blocks, self.block_size)
def allocate(
self,
needed_blocks_slots: List[Tuple[int, int]],
blocks: int,
max_blocks: int,
device: torch.device,
):
# Get free blocks indices by finding values in mask that are not set to 0
free_block_indices = self.free_block_mask.nonzero()
if blocks > len(free_block_indices):
raise RuntimeError(
f"Out of available cache blocks: asked {blocks}, only {len(free_block_indices)} free blocks"
)
# Slice by the number of required blocks
block_indices = free_block_indices[:blocks]
block_indices = block_indices.flatten()
# Padded block tables
block_tables_tensor = torch.zeros(
(len(needed_blocks_slots), max_blocks), dtype=torch.int32
)
# Allocate paged attention blocks
cumulative_blocks = 0
slots = []
block_tables = []
for i, (needed_blocks, needed_slots) in enumerate(needed_blocks_slots):
# Get allocated blocks for this sequence
allocated_blocks = block_indices[
cumulative_blocks : cumulative_blocks + needed_blocks
]
# Get slots for the allocated blocks
all_slots = self.slots[allocated_blocks].flatten()
# Repeat slots in the case of context sliding window
if needed_slots > len(all_slots) and self.repeat_slots:
repeats = math.ceil(needed_slots / len(all_slots))
all_slots = all_slots.repeat(repeats)
allocated_slots = all_slots[:needed_slots]
slots.append(allocated_slots)
block_tables.append(allocated_blocks.tolist())
block_tables_tensor[i, :needed_blocks] = allocated_blocks
cumulative_blocks += needed_blocks
block_tables = block_tables
block_tables_tensor = block_tables_tensor.to(device)
slots = torch.concat(slots).to(device)
# Allocate the required number of blocks by setting the mask to 0
self.free_block_mask[block_indices] = 0
return block_tables, block_tables_tensor, slots
def free(self, block_indices: Optional[List[int]]):
if block_indices is not None and block_indices:
# Reset mask
self.free_block_mask[block_indices] = 1
def set_cache_manager(
num_blocks: int,
num_layers: int,
num_heads: int,
head_size: int,
repeat_slots: bool,
dtype: torch.dtype,
device: torch.device,
) -> CacheManager:
global CACHE_MANAGER
if CACHE_MANAGER is not None:
del CACHE_MANAGER
torch.cuda.empty_cache()
CACHE_MANAGER = CacheManager(
num_blocks, num_layers, num_heads, head_size, repeat_slots, dtype, device
)
return CACHE_MANAGER
def get_cache_manager() -> CacheManager:
global CACHE_MANAGER
if CACHE_MANAGER is None:
raise RuntimeError("cache manager was not initialized")
return CACHE_MANAGER

View File

@@ -512,6 +512,7 @@ class FlashCohereForCausalLM(torch.nn.Module):
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
+prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
hidden_states = self.model(

View File

@@ -834,6 +834,7 @@ class FlashDbrxForCausalLM(torch.nn.Module):
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
+prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
hidden_states = self.model(

View File

@@ -458,6 +458,7 @@ class FlashGemmaForCausalLM(torch.nn.Module):
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
+prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
input_embeds = self.embed_tokens(input_ids)

View File

@@ -388,6 +388,7 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
+prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
) -> torch.Tensor:
hidden_states = self.gpt_neox(

View File

@@ -398,6 +398,7 @@ class FlashPhiForCausalLM(torch.nn.Module):
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
+prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
) -> torch.Tensor:
hidden_states = self.model(

View File

@@ -670,6 +670,7 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
+prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
) -> torch.Tensor:
hidden_states = self.transformer(

View File

@@ -482,6 +482,7 @@ class FlashSantacoderForCausalLM(nn.Module):
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
+prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
) -> torch.Tensor:
hidden_states = self.transformer(

View File

@@ -25,11 +25,6 @@ from text_generation_server.models.types import (
Generation,
GeneratedText,
)
-from text_generation_server.models.cache_manager import (
-get_cache_manager,
-set_cache_manager,
-BLOCK_SIZE,
-)
from text_generation_server.pb import generate_pb2
from text_generation_server.models.globals import MEM_POOL, CUDA_GRAPHS
import text_generation_server.models.globals as tgi_globals
@@ -44,6 +39,21 @@ from text_generation_server.utils.import_utils import (
tracer = trace.get_tracer(__name__)
+BLOCK_SIZE: int = 16
+# Will be set in init
+SLIDING_WINDOW: Optional[int] = None
+def set_sliding_window(sliding_window: int):
+global SLIDING_WINDOW
+SLIDING_WINDOW = sliding_window
+def get_sliding_windows() -> int:
+global SLIDING_WINDOW
+return SLIDING_WINDOW
@dataclass
class FlashCausalLMBatch(Batch):
@@ -55,12 +65,15 @@ class FlashCausalLMBatch(Batch):
# Decoder values
input_ids: torch.Tensor
position_ids: torch.Tensor
-speculative_ids: torch.Tensor
+speculative_ids: Optional[torch.Tensor]
# Flash Attention values
# tensor of length b containing the cumulative sequence lengths of the sequences in the batch, only used in prefill
cu_seqlen_prefill: Optional[torch.Tensor]
+# Prefill cache indices is used to slice into the kv tensor before caching it into the paged attention buffers
+# as we only keep SLIDING_WINDOW values instead of the whole tensor
+prefill_cache_indices: Optional[torch.Tensor]
# Paged Attention values
@@ -69,16 +82,13 @@ class FlashCausalLMBatch(Batch):
start_slots: torch.Tensor
# tensor of indices of the currently used slots, length = \sum_{i=0}^{b} s_i in prefill, length = b in decode
slot_indices: torch.Tensor
-# List of tuple of ints representing the number of blocks and slots needed by each sequence
-needed_blocks_slots: Optional[List[Tuple[int, int]]]
-# Set in prefill by the CacheManager
# list of length b of list of length s_i // block_size
-block_tables: Optional[List[List[int]]]
+block_tables: List[List[int]]
# tensor of size [b, max_total_seqlen // block_size] holding the paged attention block tables for all sequences
-block_tables_tensor: Optional[torch.Tensor]
+block_tables_tensor: torch.Tensor
# tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences
-slots: Optional[torch.Tensor]
+slots: torch.Tensor
max_seqlen: int
@@ -104,7 +114,7 @@ class FlashCausalLMBatch(Batch):
top_n_tokens_tensor: torch.Tensor
# Number of blocks in this batch
-blocks: int
+num_blocks: int
# Maximum number of blocks
max_blocks: int
@@ -113,7 +123,7 @@ class FlashCausalLMBatch(Batch):
id=self.batch_id,
request_ids=[r.id for r in self.requests],
size=len(self),
-max_tokens=self.blocks * BLOCK_SIZE,
+max_tokens=self.num_blocks * BLOCK_SIZE,
)
@classmethod
@@ -129,17 +139,6 @@ class FlashCausalLMBatch(Batch):
)["input_ids"]
return batch_tokenized_inputs
-@classmethod
-def from_pb(
-cls,
-pb: generate_pb2.Batch,
-tokenizer: PreTrainedTokenizerBase,
-dtype: torch.dtype,
-device: torch.device,
-) -> "FlashCausalLMBatch":
-batch_tokenized_inputs = cls.batch_tokenized_inputs(pb.requests, tokenizer)
-return cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
@classmethod
def from_tokenized(
cls,
@@ -149,12 +148,12 @@ class FlashCausalLMBatch(Batch):
dtype: torch.dtype,
device: torch.device,
) -> "FlashCausalLMBatch":
+sliding_window = get_sliding_windows()
position_ids = []
-speculative_ids = []
cu_seqlen_prefill = [0]
-needed_blocks_slots = []
start_slots = []
slot_indices = []
+prefill_cache_indices = []
input_lengths = []
prefix_offsets = []
@@ -177,11 +176,14 @@ class FlashCausalLMBatch(Batch):
cumulative_max_length = 0
prefill_out_cumulative_length = 0
-blocks = 0
+num_blocks = 0
max_seqlen = 0
max_length = 0
max_blocks = 0
+block_tables = []
+slots = []
# Parse batch
for i, (r, tokenized_input) in enumerate(
zip(pb.requests, batch_tokenized_inputs)
@@ -225,9 +227,25 @@ class FlashCausalLMBatch(Batch):
speculative_length = get_speculate()
speculative_length = 0 if speculative_length is None else speculative_length
total_tokens = input_length + max_new_tokens - 1 + speculative_length
-needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
-blocks += needed_blocks
-needed_blocks_slots.append((needed_blocks, total_tokens))
+# blocks and slots can be empty (for example in warmup)
+if not r.blocks:
+needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
+request_blocks = [
+b for b in range(num_blocks, num_blocks + needed_blocks)
+]
+request_slots = [
+s
+for b in request_blocks
+for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)
+]
+else:
+request_blocks = r.blocks
+request_slots = r.slots
+block_tables.append(request_blocks)
+slots.extend(request_slots[:total_tokens])
+num_blocks += len(request_blocks)
start_slots.append(cumulative_max_length)
request_slot_indices = torch.arange(
@@ -237,6 +255,15 @@ class FlashCausalLMBatch(Batch):
)
slot_indices.append(request_slot_indices)
+# Create tensor to slice into the kv tensor in prefill
+if sliding_window is not None:
+request_prefill_cache_indices = torch.arange(
+cumulative_length + max(0, input_length - sliding_window),
+cumulative_length + input_length,
+dtype=torch.int64,
+)
+prefill_cache_indices.append(request_prefill_cache_indices)
all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs
no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs
@@ -261,7 +288,7 @@ class FlashCausalLMBatch(Batch):
cumulative_length += input_length
cumulative_max_length += total_tokens
max_seqlen = max(max_seqlen, input_length)
-max_blocks = max(max_blocks, needed_blocks)
+max_blocks = max(max_blocks, len(request_blocks))
max_length = max(
max_length, input_length + max_new_tokens + speculative_length
)
@@ -287,16 +314,23 @@ class FlashCausalLMBatch(Batch):
input_ids = np.concatenate(all_input_ids, dtype=np.int64)
position_ids = torch.cat(position_ids)
slot_indices = torch.cat(slot_indices)
+if sliding_window is not None:
+prefill_cache_indices = torch.cat(prefill_cache_indices)
else:
input_ids = all_input_ids[0]
position_ids = position_ids[0]
slot_indices = slot_indices[0]
+if sliding_window is not None:
+prefill_cache_indices = prefill_cache_indices[0]
cu_seqlen_prefill = torch.tensor(
cu_seqlen_prefill, device=device, dtype=torch.int32
)
position_ids = position_ids.to(device)
slot_indices = slot_indices.to(device)
+prefill_cache_indices = (
+prefill_cache_indices.to(device) if sliding_window is not None else None
+)
input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
input_lengths_tensor = torch.tensor(
input_lengths, dtype=torch.int32, device=device
@@ -319,6 +353,14 @@ class FlashCausalLMBatch(Batch):
top_n_tokens, device=device, dtype=torch.int64
)
+slots = torch.tensor(slots, dtype=torch.int64, device=device)
+block_tables_tensor = torch.zeros(
+(len(block_tables), max_blocks), dtype=torch.int32, device="cpu"
+)
+for i, request_blocks in enumerate(block_tables):
+block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks)
+block_tables_tensor = block_tables_tensor.to(device)
return cls(
batch_id=pb.id,
requests=pb.requests,
@@ -326,12 +368,12 @@ class FlashCausalLMBatch(Batch):
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
+prefill_cache_indices=prefill_cache_indices,
start_slots=start_slots,
slot_indices=slot_indices,
-needed_blocks_slots=needed_blocks_slots,
-block_tables=None,
-block_tables_tensor=None,
-slots=None,
+block_tables=block_tables,
+block_tables_tensor=block_tables_tensor,
+slots=slots,
max_seqlen=max_seqlen,
prefill_head_indices=prefill_head_indices,
prefill_next_token_indices=prefill_next_token_indices,
@@ -346,11 +388,22 @@ class FlashCausalLMBatch(Batch):
stopping_criterias=stopping_criterias,
top_n_tokens=top_n_tokens,
top_n_tokens_tensor=top_n_tokens_tensor,
-blocks=blocks,
+num_blocks=num_blocks,
max_blocks=max_blocks,
speculative_ids=None,
)
+@classmethod
+def from_pb(
+cls,
+pb: generate_pb2.Batch,
+tokenizer: PreTrainedTokenizerBase,
+dtype: torch.dtype,
+device: torch.device,
+) -> "FlashCausalLMBatch":
+batch_tokenized_inputs = cls.batch_tokenized_inputs(pb.requests, tokenizer)
+return cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
@tracer.start_as_current_span("filter")
def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
if len(request_ids) == 0:
@@ -388,7 +441,7 @@ class FlashCausalLMBatch(Batch):
stopping_criterias = []
top_n_tokens = []
-blocks = 0
+num_blocks = 0
max_blocks = 0
# Cumulative length
cumulative_max_length = 0
@@ -420,7 +473,7 @@ class FlashCausalLMBatch(Batch):
)
request_block_table = self.block_tables[idx]
-blocks += len(request_block_table)
+num_blocks += len(request_block_table)
block_tables.append(request_block_table)
start_slots.append(cumulative_max_length)
@@ -439,17 +492,6 @@ class FlashCausalLMBatch(Batch):
max_blocks = max(max_blocks, len(request_block_table))
-block_indices_to_free = []
-# Iterate on all requests
-for i, r in enumerate(self.requests):
-# Filter requests that are not part of the new batch
-if r.id not in requests_idx_mapping.keys():
-block_indices_to_free.extend(self.block_tables[i])
-# Free blocks
-get_cache_manager().free(block_indices_to_free)
-# Needed to avoid dropping blocks when the batches will go out of scope
-self.block_tables = None
# Index into tensors
input_ids = self.input_ids[indices]
position_ids = self.position_ids[indices]
@@ -475,9 +517,9 @@ class FlashCausalLMBatch(Batch):
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=None,
+prefill_cache_indices=None,
start_slots=start_slots,
slot_indices=slot_indices,
-needed_blocks_slots=None,
block_tables=block_tables,
block_tables_tensor=block_tables_tensor,
slots=slots,
@@ -495,7 +537,7 @@ class FlashCausalLMBatch(Batch):
stopping_criterias=stopping_criterias,
top_n_tokens=top_n_tokens,
top_n_tokens_tensor=top_n_tokens_tensor,
-blocks=blocks,
+num_blocks=num_blocks,
max_blocks=max_blocks,
speculative_ids=speculative_ids,
)
@@ -507,7 +549,7 @@ class FlashCausalLMBatch(Batch):
requests = []
requests_idx_mapping = {}
-blocks = 0
+num_blocks = 0
total_batch_size = 0
total_slots = 0
max_blocks = 0
@@ -516,7 +558,7 @@ class FlashCausalLMBatch(Batch):
for b in batches:
total_batch_size += len(b)
total_slots += len(b.slots)
-blocks += b.blocks
+num_blocks += b.num_blocks
speculative_length = (
b.speculative_ids.shape[1] if b.speculative_ids is not None else 0
)
@@ -635,11 +677,6 @@ class FlashCausalLMBatch(Batch):
else None
)
-# Needed to avoid dropping blocks when the batches will go out of scope
-for b in batches:
-b.block_tables = None
-del b
return cls(
batch_id=batches[0].batch_id,
requests=requests,
@@ -647,9 +684,9 @@ class FlashCausalLMBatch(Batch):
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=None,
+prefill_cache_indices=None,
start_slots=start_slots,
slot_indices=slot_indices,
-needed_blocks_slots=None,
block_tables=block_tables,
block_tables_tensor=block_tables_tensor,
slots=slots,
@@ -667,18 +704,11 @@ class FlashCausalLMBatch(Batch):
stopping_criterias=stopping_criterias,
top_n_tokens=top_n_tokens,
top_n_tokens_tensor=top_n_tokens_tensor,
-blocks=blocks,
+num_blocks=num_blocks,
max_blocks=max_blocks,
speculative_ids=speculative_ids,
)
-def __del__(self):
-if self.block_tables is not None and self.block_tables:
-# Free blocks
-get_cache_manager().free(
-list(itertools.chain.from_iterable(self.block_tables))
-)
def __len__(self):
return len(self.requests)
@@ -702,6 +732,7 @@ class FlashCausalLM(Model):
self.head_size = head_size
self.cuda_graphs = {}
+self.kv_cache = []
super(FlashCausalLM, self).__init__(
model=model,
@@ -718,6 +749,43 @@ class FlashCausalLM(Model):
def batch_type(self) -> Type[FlashCausalLMBatch]:
return FlashCausalLMBatch
+def max_past(self) -> int:
+return getattr(self.model, "max_past", None)
+def init_kv_cache(
+self,
+num_blocks: int,
+num_layers: int,
+num_heads: int,
+head_size: int,
+dtype: torch.dtype,
+device: torch.device,
+):
+self.kv_cache = []
+empty_cache()
+element_size = torch.tensor([], dtype=dtype).element_size()
+if SYSTEM == "xpu":
+x = 1
+else:
+x = BLOCK_SIZE // element_size
+self.kv_cache = [
+(
+torch.empty(
+(num_blocks, num_heads, head_size // x, BLOCK_SIZE, x),
+dtype=dtype,
+device=device,
+),
+torch.empty(
+(num_blocks, num_heads, head_size, BLOCK_SIZE),
+dtype=dtype,
+device=device,
+),
+)
+for _ in range(num_layers)
+]
def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
@@ -728,12 +796,11 @@ class FlashCausalLM(Model):
.repeat(bs)
.reshape((bs, max_bt))
)
-kv_cache = get_cache_manager().kv_cache
self.cuda_graphs[bs] = {
"input_ids": input_ids,
"position_ids": position_ids,
-"kv_cache": kv_cache,
+"kv_cache": self.kv_cache,
"block_tables": block_tables,
"slots": slots,
"input_lengths": input_lengths,
@@ -747,11 +814,12 @@ class FlashCausalLM(Model):
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=None,
-kv_cache=kv_cache,
+kv_cache=self.kv_cache,
block_tables=block_tables,
slots=slots,
input_lengths=input_lengths,
max_s=max_s,
+prefill_cache_indices=None,
lm_head_indices=None,
)
torch.cuda.synchronize()
@@ -761,11 +829,12 @@ class FlashCausalLM(Model):
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=None,
-kv_cache=kv_cache,
+kv_cache=self.kv_cache,
block_tables=block_tables,
slots=slots,
input_lengths=input_lengths,
max_s=max_s,
+prefill_cache_indices=None,
lm_head_indices=None,
)
self.cuda_graphs[bs]["logits"] = logits
@ -777,17 +846,16 @@ class FlashCausalLM(Model):
empty_cache() empty_cache()
try: try:
cache_manager = set_cache_manager( self.init_kv_cache(
batch.blocks, batch.num_blocks,
self.num_layers, self.num_layers,
self.num_kv_heads, self.num_kv_heads,
self.head_size, self.head_size,
self.sliding_window is not None,
self.dtype, self.dtype,
self.device, self.device,
) )
max_bt = batch.max_blocks max_bt = batch.max_blocks
max_s = max_bt * get_cache_manager().block_size max_s = max_bt * BLOCK_SIZE
if SYSTEM == "rocm" and os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False): if SYSTEM == "rocm" and os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False):
torch.cuda.tunable.tuning_enable(False) torch.cuda.tunable.tuning_enable(False)
@ -811,19 +879,17 @@ class FlashCausalLM(Model):
num_blocks = ( num_blocks = (
# Leave 5% for some wiggle room # Leave 5% for some wiggle room
int((free_memory * 0.95) // total_cache_size) int((free_memory * 0.95) // total_cache_size)
# Add batch.blocks as we allocated it above, so it is included in the peak memory. # Add batch.num_blocks as we allocated it above, so it is included in the peak memory.
+ cache_manager.num_blocks + batch.num_blocks
) )
del batch del batch
del cache_manager
set_cache_manager( self.init_kv_cache(
num_blocks, num_blocks,
self.num_layers, self.num_layers,
self.num_kv_heads, self.num_kv_heads,
self.head_size, self.head_size,
self.sliding_window is not None,
self.dtype, self.dtype,
self.device, self.device,
) )
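The warmup above sizes the permanent cache from whatever memory is still free after one prefill: 95% of the free bytes divided by the per-block cache footprint, plus the blocks the warmup batch itself was holding. A worked sketch of that arithmetic, with hypothetical numbers and the per-block footprint derived from the init_kv_cache shapes above:

# Hypothetical single-GPU warmup numbers, illustrative only.
free_memory = 40 * 1024**3          # 40 GiB still free after the warmup prefill
num_layers, num_heads, head_size = 32, 32, 128
element_size, BLOCK_SIZE = 2, 16    # fp16, assumed block size

# Per-block footprint: K and V tensors across all layers (see init_kv_cache).
total_cache_size = 2 * num_layers * num_heads * head_size * BLOCK_SIZE * element_size

warmup_blocks = 512                 # blocks the warmup batch itself had allocated
num_blocks = int(free_memory * 0.95 // total_cache_size) + warmup_blocks
# total_cache_size = 8,388,608 bytes (8 MiB per block)
# num_blocks = 4864 + 512 = 5376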
@@ -889,7 +955,6 @@ class FlashCausalLM(Model):
input_ids = torch.zeros(seqlen, dtype=torch.int64, device=self.device) input_ids = torch.zeros(seqlen, dtype=torch.int64, device=self.device)
position_ids = torch.zeros(seqlen, dtype=torch.int32, device=self.device) position_ids = torch.zeros(seqlen, dtype=torch.int32, device=self.device)
slots = torch.arange(seqlen, dtype=torch.int64, device=self.device) slots = torch.arange(seqlen, dtype=torch.int64, device=self.device)
kv_cache = get_cache_manager().kv_cache
# Dummy value, some models (starcoder2) don't accept `None`. # Dummy value, some models (starcoder2) don't accept `None`.
input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device) input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
@@ -901,12 +966,13 @@ class FlashCausalLM(Model):
cu_seqlen_prefill=torch.tensor( cu_seqlen_prefill=torch.tensor(
[0, seqlen], device=self.device, dtype=torch.int32 [0, seqlen], device=self.device, dtype=torch.int32
), ),
kv_cache=get_cache_manager().kv_cache, kv_cache=self.kv_cache,
block_tables=None, block_tables=None,
input_lengths=input_lengths, input_lengths=input_lengths,
slots=slots, slots=slots,
max_s=seqlen, max_s=seqlen,
lm_head_indices=None, lm_head_indices=None,
prefill_cache_indices=None,
) )
def forward( def forward(
@@ -917,7 +983,7 @@ class FlashCausalLM(Model):
input_ids = batch.input_ids input_ids = batch.input_ids
position_ids = batch.position_ids position_ids = batch.position_ids
cu_seqlen_prefill = batch.cu_seqlen_prefill cu_seqlen_prefill = batch.cu_seqlen_prefill
kv_cache = get_cache_manager().kv_cache kv_cache = self.kv_cache
block_tables = batch.block_tables_tensor block_tables = batch.block_tables_tensor
slots = batch.slots[batch.slot_indices] slots = batch.slots[batch.slot_indices]
input_lengths = batch.input_lengths_tensor input_lengths = batch.input_lengths_tensor
@@ -956,13 +1022,19 @@ class FlashCausalLM(Model):
input_ids = batch.input_ids input_ids = batch.input_ids
position_ids = batch.position_ids position_ids = batch.position_ids
cu_seqlen_prefill = batch.cu_seqlen_prefill cu_seqlen_prefill = batch.cu_seqlen_prefill
kv_cache = get_cache_manager().kv_cache kv_cache = self.kv_cache
block_tables = batch.block_tables_tensor block_tables = batch.block_tables_tensor
slots = batch.slots[batch.slot_indices] slots = batch.slots[batch.slot_indices]
input_lengths = batch.input_lengths_tensor input_lengths = batch.input_lengths_tensor
max_s = batch.max_seqlen max_s = batch.max_seqlen
lm_head_indices = batch.prefill_head_indices lm_head_indices = batch.prefill_head_indices
if cu_seqlen_prefill is None and self.max_past() is not None:
# In decode, not prefill, we're actually overwriting the KV-cache
# in a circular buffer mode.
# This makes sure the max_s for the decode pass is correct.
max_s = min(self.max_past(), max_s)
bs = input_ids.shape[0] bs = input_ids.shape[0]
sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs]) sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
if sorted_padded_bs: if sorted_padded_bs:
@@ -972,7 +1044,7 @@ class FlashCausalLM(Model):
cuda_graph = None cuda_graph = None
if cu_seqlen_prefill is not None or cuda_graph is None: if cu_seqlen_prefill is not None or cuda_graph is None:
return self.model.forward( logits, speculative_logits = self.model.forward(
input_ids=input_ids, input_ids=input_ids,
position_ids=position_ids, position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill, cu_seqlen_prefill=cu_seqlen_prefill,
@@ -981,8 +1053,12 @@ class FlashCausalLM(Model):
slots=slots, slots=slots,
input_lengths=input_lengths, input_lengths=input_lengths,
max_s=max_s, max_s=max_s,
prefill_cache_indices=batch.prefill_cache_indices,
lm_head_indices=lm_head_indices, lm_head_indices=lm_head_indices,
) )
if batch.prefill_cache_indices is not None:
batch.prefill_cache_indices = None
return logits, speculative_logits
# Copy inputs to the static inputs of the cuda graph # Copy inputs to the static inputs of the cuda graph
# Static inputs are potentially padded # Static inputs are potentially padded
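Two decode-time details from the hunks above are easy to miss: with a sliding window, max_s is clamped to max_past because the KV cache is overwritten as a circular buffer during decode, and the CUDA-graph path now picks the smallest captured graph whose batch size covers the live batch. A minimal sketch of both rules (illustrative helpers, not the class's methods):

def pick_cuda_graph(cuda_graphs: dict, bs: int):
    # Smallest captured graph that can hold this batch, if any
    # (mirrors sorted_padded_bs above).
    candidates = sorted(k for k in cuda_graphs if k >= bs)
    return cuda_graphs[candidates[0]] if candidates else None

def clamp_decode_max_s(max_s: int, max_past) -> int:
    # In decode the KV cache behaves as a circular buffer of size max_past,
    # so positions beyond it have been overwritten and max_s must be capped.
    return min(max_past, max_s) if max_past is not None else max_s

# pick_cuda_graph({8: "g8", 16: "g16", 32: "g32"}, bs=11) -> "g16"
# clamp_decode_max_s(6000, max_past=4096) -> 4096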
@@ -1015,24 +1091,7 @@ class FlashCausalLM(Model):
prefill = batch.cu_seqlen_prefill is not None prefill = batch.cu_seqlen_prefill is not None
prefill_logprobs = batch.prefill_next_token_indices is not None prefill_logprobs = batch.prefill_next_token_indices is not None
if batch.needed_blocks_slots: out, speculative_logits = self.forward(batch)
# Allocate blocks to this batch
block_tables, block_tables_tensor, slots = get_cache_manager().allocate(
batch.needed_blocks_slots,
batch.blocks,
batch.max_blocks,
batch.input_ids.device,
)
batch.needed_blocks_slots = None
batch.block_tables = block_tables
batch.block_tables_tensor = block_tables_tensor
batch.slots = slots
try:
out, speculative_logits = self.forward(batch)
except Exception as e:
del batch
raise e
if prefill: if prefill:
next_token_logits = ( next_token_logits = (
@@ -1327,7 +1386,6 @@ class FlashCausalLM(Model):
batch.all_input_ids[i] = all_input_ids batch.all_input_ids[i] = all_input_ids
if stopped: if stopped:
del batch
# No need to return a batch if we know that all requests stopped # No need to return a batch if we know that all requests stopped
forward_ns = start_decode - start forward_ns = start_decode - start
decode_ns = time.time_ns() - start_decode decode_ns = time.time_ns() - start_decode
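With the Python-side `get_cache_manager().allocate(...)` call gone, batches reach `generate_token` with `block_tables`, `block_tables_tensor` and `slots` already populated (allocation now happens upstream, per the commit title, on the router side). A conceptual sketch of the block-table-to-slot mapping that the paged KV cache relies on, assuming a BLOCK_SIZE of 16 — illustrative only, not the repository's allocator:

BLOCK_SIZE = 16  # assumed

def slots_for_sequence(block_table: list[int], num_tokens: int) -> list[int]:
    # Token i of a sequence lives in block block_table[i // BLOCK_SIZE],
    # at offset i % BLOCK_SIZE inside that block.
    return [
        block_table[i // BLOCK_SIZE] * BLOCK_SIZE + i % BLOCK_SIZE
        for i in range(num_tokens)
    ]

# slots_for_sequence([7, 2], num_tokens=20)
# -> tokens 0..15 land in slots 112..127 (block 7), tokens 16..19 in 32..35 (block 2)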

View File

@@ -1,308 +1,24 @@
import math
import torch import torch
import torch.distributed import torch.distributed
import numpy as np
from dataclasses import dataclass
from opentelemetry import trace from opentelemetry import trace
from transformers import PreTrainedTokenizerBase, AutoTokenizer, AutoConfig from transformers import AutoTokenizer, AutoConfig
from typing import Optional, Tuple, Type from typing import Optional, Tuple
from text_generation_server.pb import generate_pb2
from text_generation_server.models import FlashCausalLM from text_generation_server.models import FlashCausalLM
from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch, BLOCK_SIZE from text_generation_server.models.flash_causal_lm import set_sliding_window
from text_generation_server.models.cache_manager import (
get_cache_manager,
)
from text_generation_server.models.custom_modeling.flash_mistral_modeling import ( from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
FlashMistralForCausalLM, FlashMistralForCausalLM,
MistralConfig, MistralConfig,
) )
from text_generation_server.utils.speculate import get_speculate
from text_generation_server.utils import ( from text_generation_server.utils import (
initialize_torch_distributed, initialize_torch_distributed,
weight_files, weight_files,
Weights, Weights,
HeterogeneousNextTokenChooser,
StoppingCriteria,
) )
tracer = trace.get_tracer(__name__)
# Will be set in init
SLIDING_WINDOW: Optional[int] = None
SLIDING_WINDOW_BLOCKS: Optional[int] = None
from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.import_utils import SYSTEM
MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None tracer = trace.get_tracer(__name__)
def set_sliding_window(sliding_window: int, sliding_window_blocks: int):
global SLIDING_WINDOW
global SLIDING_WINDOW_BLOCKS
SLIDING_WINDOW = sliding_window
SLIDING_WINDOW_BLOCKS = sliding_window_blocks
def get_sliding_windows() -> Tuple[int, int]:
global SLIDING_WINDOW
global SLIDING_WINDOW_BLOCKS
return SLIDING_WINDOW, SLIDING_WINDOW_BLOCKS
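These removed globals tracked the window both in tokens and in blocks; the surviving `set_sliding_window` takes a single value (see the updated call sites below, e.g. `set_sliding_window(config.sliding_window)`), and the block count can be recomputed where needed with a ceiling division. A small sketch, assuming a BLOCK_SIZE of 16:

import math

BLOCK_SIZE = 16  # assumed paged-attention block size

def sliding_window_blocks(sliding_window: int) -> int:
    # Number of KV-cache blocks needed to hold one full attention window.
    return math.ceil(sliding_window / BLOCK_SIZE)

# e.g. a 4096-token window: sliding_window_blocks(4096) -> 256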
# Adds windowing logic to FlashCausalLMBatch
@dataclass
class FlashMistralBatch(FlashCausalLMBatch):
# Prefill cache indices are used to slice into the kv tensor before caching it into the paged attention buffers
# as we only keep SLIDING_WINDOW values instead of the whole tensor
prefill_cache_indices: Optional[torch.Tensor] = None
@classmethod
def from_pb(
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
dtype: torch.dtype,
device: torch.device,
) -> "FlashCausalLMBatch":
batch_tokenized_inputs = cls.batch_tokenized_inputs(pb.requests, tokenizer)
return cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
@classmethod
def from_tokenized(
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
batch_tokenized_inputs,
dtype: torch.dtype,
device: torch.device,
) -> "FlashCausalLMBatch":
sliding_window, sliding_window_blocks = get_sliding_windows()
position_ids = []
cu_seqlen_prefill = [0]
needed_blocks_slots = []
start_slots = []
slot_indices = []
prefill_cache_indices = []
input_lengths = []
prefix_offsets = []
read_offsets = []
all_input_ids = []
requests_idx_mapping = {}
all_prefill_logprobs = True
no_prefill_logprobs = True
prefill_head_indices = []
prefill_next_token_indices = []
prefill_cu_outlens = [0]
next_token_chooser_parameters = []
stopping_criterias = []
top_n_tokens = []
# Cumulative length
cumulative_length = 0
cumulative_max_length = 0
prefill_out_cumulative_length = 0
blocks = 0
max_seqlen = 0
max_length = 0
max_blocks = 0
# Parse batch
for i, (r, tokenized_input) in enumerate(
zip(pb.requests, batch_tokenized_inputs)
):
# request id -> idx in list mapping
requests_idx_mapping[r.id] = i
tokenized_input = tokenized_input[-r.truncate :]
if (
tokenized_input[0] == tokenizer.bos_token_id
and tokenized_input[1] == tokenizer.bos_token_id
):
tokenized_input = tokenized_input[1:]
input_length = len(tokenized_input)
input_lengths.append(input_length)
prefix_offsets.append(input_length - 5)
read_offsets.append(input_length)
all_input_ids.append(tokenized_input)
# Position ids
request_position_ids = torch.arange(0, input_length, dtype=torch.int32)
position_ids.append(request_position_ids)
# Add cumulative lengths of all previous inputs
cu_seqlen_prefill.append(cumulative_length + input_length)
next_token_chooser_parameters.append(r.parameters)
stopping_criteria = StoppingCriteria.from_pb(
r.stopping_parameters, tokenizer
)
max_new_tokens = stopping_criteria.max_new_tokens
stopping_criterias.append(stopping_criteria)
top_n_tokens.append(r.top_n_tokens)
# Paged attention
# Remove one as the first token does not have a past
speculative_length = get_speculate()
total_tokens = input_length + max_new_tokens - 1 + speculative_length
# Needed blocks can not go over SLIDING_WINDOW_BLOCKS
needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
if sliding_window_blocks is not None:
needed_blocks = min(needed_blocks, sliding_window_blocks)
blocks += needed_blocks
needed_blocks_slots.append((needed_blocks, total_tokens))
start_slots.append(cumulative_max_length)
request_slot_indices = torch.arange(
cumulative_max_length,
cumulative_max_length + input_length,
dtype=torch.int64,
)
slot_indices.append(request_slot_indices)
# Create tensor to slice into the kv tensor in prefill
if sliding_window is not None:
request_prefill_cache_indices = torch.arange(
cumulative_length + max(0, input_length - sliding_window),
cumulative_length + input_length,
dtype=torch.int64,
)
prefill_cache_indices.append(request_prefill_cache_indices)
all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs
no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs
if r.prefill_logprobs:
prefill_head_indices.append(request_position_ids + cumulative_length)
prefill_next_token_indices.append(
prefill_out_cumulative_length + input_length - 1
)
prefill_cu_outlens.append(prefill_out_cumulative_length + input_length)
prefill_out_cumulative_length += input_length
else:
prefill_head_indices.append(
torch.tensor(
[cumulative_length + input_length - 1], dtype=torch.int32
)
)
prefill_next_token_indices.append(prefill_out_cumulative_length)
prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
prefill_out_cumulative_length += 1
# Update
cumulative_length += input_length
cumulative_max_length += total_tokens
max_seqlen = max(max_seqlen, input_length)
max_blocks = max(max_blocks, needed_blocks)
max_length = max(
max_length, input_length + max_new_tokens + speculative_length
)
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
next_token_chooser_parameters, dtype, device, tokenizer
)
start_slots = torch.tensor(start_slots, dtype=torch.int64)
# Padded all_input_ids_tensor
all_input_ids_tensor = np.zeros(
(len(all_input_ids), max_length), dtype=np.int64
)
for i, input_ids in enumerate(all_input_ids):
all_input_ids_tensor[i, : len(input_ids)] = input_ids
# Create tensors on device
all_input_ids_tensor = torch.tensor(
all_input_ids_tensor, dtype=torch.int64, device=device
)
if len(pb.requests) > 1:
input_ids = np.concatenate(all_input_ids, dtype=np.int64)
position_ids = torch.cat(position_ids)
slot_indices = torch.cat(slot_indices)
if sliding_window is not None:
prefill_cache_indices = torch.cat(prefill_cache_indices)
else:
input_ids = all_input_ids[0]
position_ids = position_ids[0]
slot_indices = slot_indices[0]
if sliding_window is not None:
prefill_cache_indices = prefill_cache_indices[0]
cu_seqlen_prefill = torch.tensor(
cu_seqlen_prefill, device=device, dtype=torch.int32
)
position_ids = position_ids.to(device)
slot_indices = slot_indices.to(device)
prefill_cache_indices = (
prefill_cache_indices.to(device) if sliding_window is not None else None
)
input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
input_lengths_tensor = torch.tensor(
input_lengths, dtype=torch.int32, device=device
)
if all_prefill_logprobs:
prefill_head_indices = None
prefill_next_token_indices = cu_seqlen_prefill[1:] - 1
elif no_prefill_logprobs:
prefill_head_indices = cu_seqlen_prefill[1:] - 1
prefill_next_token_indices = None
else:
prefill_head_indices = torch.tensor(
torch.cat(prefill_head_indices), dtype=torch.int64, device=device
)
prefill_next_token_indices = torch.tensor(
prefill_next_token_indices, dtype=torch.int64, device=device
)
top_n_tokens_tensor = torch.tensor(
top_n_tokens, device=device, dtype=torch.int64
)
return cls(
batch_id=pb.id,
requests=pb.requests,
requests_idx_mapping=requests_idx_mapping,
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
start_slots=start_slots,
slot_indices=slot_indices,
needed_blocks_slots=needed_blocks_slots,
block_tables=None,
block_tables_tensor=None,
slots=None,
max_seqlen=max_seqlen,
prefill_head_indices=prefill_head_indices,
prefill_next_token_indices=prefill_next_token_indices,
prefill_cu_outlens=prefill_cu_outlens,
input_lengths=input_lengths,
input_lengths_tensor=input_lengths_tensor,
prefix_offsets=prefix_offsets,
read_offsets=read_offsets,
all_input_ids=all_input_ids,
all_input_ids_tensor=all_input_ids_tensor,
next_token_chooser=next_token_chooser,
stopping_criterias=stopping_criterias,
top_n_tokens=top_n_tokens,
top_n_tokens_tensor=top_n_tokens_tensor,
blocks=blocks,
max_blocks=max_blocks,
prefill_cache_indices=prefill_cache_indices,
speculative_ids=None,
)
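The removed batch class computed `prefill_cache_indices` so that, under a sliding window, only the last `sliding_window` prefill positions of each request are written into the paged KV cache. A worked example of that arithmetic with hypothetical lengths (helper name is illustrative):

import torch

def request_prefill_cache_indices(cumulative_length: int, input_length: int,
                                  sliding_window: int) -> torch.Tensor:
    # Keep only the last `sliding_window` positions of this request's prefill,
    # offset by the tokens of all previous requests in the flattened batch.
    return torch.arange(
        cumulative_length + max(0, input_length - sliding_window),
        cumulative_length + input_length,
        dtype=torch.int64,
    )

# First request: 10 prompt tokens, window of 4 -> indices [6, 7, 8, 9]
# Second request at cumulative_length=10 with 3 tokens (< window)
# -> indices [10, 11, 12] (nothing is dropped)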
class BaseFlashMistral(FlashCausalLM): class BaseFlashMistral(FlashCausalLM):
@@ -344,9 +60,7 @@ class BaseFlashMistral(FlashCausalLM):
# Set context windows # Set context windows
if getattr(config, "sliding_window", None) is not None: if getattr(config, "sliding_window", None) is not None:
set_sliding_window( set_sliding_window(config.sliding_window)
config.sliding_window, math.ceil(config.sliding_window / BLOCK_SIZE)
)
else: else:
config.sliding_window = None config.sliding_window = None
@@ -384,207 +98,6 @@ class BaseFlashMistral(FlashCausalLM):
model.model.head_size, model.model.head_size,
) )
def max_past(self) -> int:
return self.model.max_past
@property
def batch_type(self) -> Type[FlashMistralBatch]:
return FlashMistralBatch
def tunableop_warmup(self, seqlen: int):
input_ids = torch.zeros(seqlen, dtype=torch.int64, device=self.device)
position_ids = torch.zeros(seqlen, dtype=torch.int32, device=self.device)
slots = torch.arange(seqlen, dtype=torch.int64, device=self.device)
kv_cache = get_cache_manager().kv_cache
# Dummy value, some models (starcoder2) don't accept `None`.
input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
# We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=torch.tensor(
[0, seqlen], device=self.device, dtype=torch.int32
),
kv_cache=get_cache_manager().kv_cache,
block_tables=None,
input_lengths=input_lengths,
slots=slots,
max_s=seqlen,
lm_head_indices=None,
prefill_cache_indices=None,
)
def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
slots = torch.arange(bs, dtype=torch.int64, device=self.device)
input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
block_tables = (
torch.arange(max_bt, dtype=torch.int32, device=self.device)
.repeat(bs)
.reshape((bs, max_bt))
)
kv_cache = get_cache_manager().kv_cache
self.cuda_graphs[bs] = {
"input_ids": input_ids,
"position_ids": position_ids,
"kv_cache": kv_cache,
"block_tables": block_tables,
"slots": slots,
"input_lengths": input_lengths,
}
graph = torch.cuda.CUDAGraph()
self.cuda_graphs[bs]["graph"] = graph
torch.cuda.synchronize()
# Run once outside to warmup
self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=None,
kv_cache=kv_cache,
block_tables=block_tables,
slots=slots,
input_lengths=input_lengths,
max_s=max_s,
prefill_cache_indices=None,
lm_head_indices=None,
)
torch.cuda.synchronize()
with torch.cuda.graph(graph, pool=MEM_POOL):
logits, speculative_logits = self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=None,
kv_cache=kv_cache,
block_tables=block_tables,
slots=slots,
input_lengths=input_lengths,
max_s=max_s,
prefill_cache_indices=None,
lm_head_indices=None,
)
self.cuda_graphs[bs]["logits"] = logits
self.cuda_graphs[bs]["speculative_logits"] = speculative_logits
torch.cuda.synchronize()
def forward(
self, batch: FlashMistralBatch
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
# Model Forward
if batch.speculative_ids is not None:
input_ids = batch.input_ids
position_ids = batch.position_ids
cu_seqlen_prefill = batch.cu_seqlen_prefill
kv_cache = get_cache_manager().kv_cache
block_tables = batch.block_tables_tensor
slots = batch.slots[batch.slot_indices]
input_lengths = batch.input_lengths_tensor
max_s = batch.max_seqlen
lm_head_indices = batch.prefill_head_indices
speculative_ids = batch.speculative_ids
B, speculative_length = speculative_ids.shape
new_length = speculative_length + 1
new_input_ids = torch.cat(
[input_ids.unsqueeze(-1), speculative_ids], dim=1
).reshape(-1)
arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
arange_int = arange.to(dtype=torch.int32)
new_position_ids = (
position_ids.unsqueeze(-1).expand(B, new_length) + arange
).view(-1)
slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
input_lengths = (
input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
).view(-1)
# Copy the block tables for all members
block_tables = (
block_tables.unsqueeze(1)
.expand(B, new_length, -1)
.reshape(B * new_length, -1)
.contiguous()
)
max_s = max_s + speculative_length
input_ids = new_input_ids
position_ids = new_position_ids
else:
input_ids = batch.input_ids
position_ids = batch.position_ids
cu_seqlen_prefill = batch.cu_seqlen_prefill
kv_cache = get_cache_manager().kv_cache
block_tables = batch.block_tables_tensor
slots = batch.slots[batch.slot_indices]
input_lengths = batch.input_lengths_tensor
max_s = batch.max_seqlen
lm_head_indices = batch.prefill_head_indices
if cu_seqlen_prefill is None and self.max_past() is not None:
# In decode, not prefill, we're actually overwriting the KV-cache
# in a circular buffer mode.
# This makes sure the max_s for the decode pass is correct.
max_s = min(self.max_past(), max_s)
bs = input_ids.shape[0]
padded_bs = bs
if bs == 3:
padded_bs = 4
elif 3 < bs <= 8:
padded_bs = 8
elif bs > 8:
padded_bs = (bs + 7) // 8 * 8
# Try to find an associated cuda graph
cuda_graph = self.cuda_graphs.get(padded_bs, None)
if cu_seqlen_prefill is not None or cuda_graph is None:
logits, speculative_logits = self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,
block_tables=block_tables,
slots=slots,
input_lengths=input_lengths,
max_s=max_s,
prefill_cache_indices=batch.prefill_cache_indices,
lm_head_indices=lm_head_indices,
)
if batch.prefill_cache_indices is not None:
batch.prefill_cache_indices = None
return logits, speculative_logits
# Copy inputs to the static inputs of the cuda graph
# Static inputs are potentially padded
cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
cuda_graph["block_tables"][
: block_tables.shape[0], : block_tables.shape[1]
] = block_tables
cuda_graph["slots"].fill_(-1)
cuda_graph["slots"][: slots.shape[0]] = slots
cuda_graph["input_lengths"].zero_()
cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
# Replay the graph
cuda_graph["graph"].replay()
# Slice output to the correct shape
speculative_logits = (
cuda_graph["speculative_logits"][:bs]
if cuda_graph["speculative_logits"] is not None
else None
)
logits = cuda_graph["logits"][:bs]
return logits, speculative_logits
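The removed forward also carried the old CUDA-graph batch-size padding (3 pads to 4, 4 to 8 pads to 8, larger sizes round up to a multiple of 8); the consolidated `FlashCausalLM.forward` now just takes the smallest captured graph at least as large as the live batch. The old rounding rule as a small sketch:

def padded_batch_size(bs: int) -> int:
    # Mirrors the removed padding rule: bs of 3 pads to 4, 4..8 pads to 8,
    # anything larger rounds up to the next multiple of 8.
    if bs == 3:
        return 4
    if 3 < bs <= 8:
        return 8
    if bs > 8:
        return (bs + 7) // 8 * 8
    return bs  # bs of 1 or 2 is used as-is

# padded_batch_size(3) -> 4, padded_batch_size(5) -> 8, padded_batch_size(13) -> 16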
class FlashMistral(BaseFlashMistral): class FlashMistral(BaseFlashMistral):
def __init__( def __init__(

View File

@@ -7,7 +7,6 @@ from opentelemetry import trace
from transformers import AutoTokenizer, AutoConfig from transformers import AutoTokenizer, AutoConfig
from typing import Optional from typing import Optional
from text_generation_server.models.cache_manager import BLOCK_SIZE
from text_generation_server.models.flash_mistral import ( from text_generation_server.models.flash_mistral import (
BaseFlashMistral, BaseFlashMistral,
set_sliding_window, set_sliding_window,
@@ -57,9 +56,7 @@ class FlashQwen2(BaseFlashMistral):
# Set context windows # Set context windows
if config.sliding_window is not None: if config.sliding_window is not None:
set_sliding_window( set_sliding_window(config.sliding_window)
config.sliding_window, math.ceil(config.sliding_window / BLOCK_SIZE)
)
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)

View File

@@ -6,7 +6,6 @@ from typing import Optional
from transformers.models.gpt2 import GPT2TokenizerFast from transformers.models.gpt2 import GPT2TokenizerFast
from text_generation_server.models.cache_manager import BLOCK_SIZE
from text_generation_server.models.flash_mistral import ( from text_generation_server.models.flash_mistral import (
BaseFlashMistral, BaseFlashMistral,
set_sliding_window, set_sliding_window,
@@ -56,9 +55,7 @@ class FlashStarcoder2(BaseFlashMistral):
# Set context windows # Set context windows
if config.sliding_window is not None: if config.sliding_window is not None:
set_sliding_window( set_sliding_window(config.sliding_window)
config.sliding_window, math.ceil(config.sliding_window / BLOCK_SIZE)
)
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)

View File

@@ -11,13 +11,9 @@ from typing import Optional, Tuple, List, Type, Dict
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
from transformers.image_processing_utils import select_best_resolution from transformers.image_processing_utils import select_best_resolution
from text_generation_server.pb import generate_pb2 from text_generation_server.pb import generate_pb2
from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch
from text_generation_server.models.flash_mistral import ( from text_generation_server.models.flash_mistral import (
BaseFlashMistral, BaseFlashMistral,
FlashMistralBatch,
)
from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch
from text_generation_server.models.cache_manager import (
get_cache_manager,
) )
tracer = trace.get_tracer(__name__) tracer = trace.get_tracer(__name__)
@@ -140,7 +136,7 @@ def load_data_uri(image_uri: str) -> Image.Image:
return image return image
class VlmCausalLMBatch(FlashMistralBatch): class VlmCausalLMBatch(FlashCausalLMBatch):
pixel_values: Optional[List[torch.Tensor]] pixel_values: Optional[List[torch.Tensor]]
pixel_attention_mask: Optional[List[torch.Tensor]] pixel_attention_mask: Optional[List[torch.Tensor]]
image_sizes: Optional[List[Tuple[int, int]]] image_sizes: Optional[List[Tuple[int, int]]]
@@ -268,7 +264,7 @@ class VlmCausalLM(BaseFlashMistral):
input_ids = batch.input_ids input_ids = batch.input_ids
position_ids = batch.position_ids position_ids = batch.position_ids
cu_seqlen_prefill = batch.cu_seqlen_prefill cu_seqlen_prefill = batch.cu_seqlen_prefill
kv_cache = get_cache_manager().kv_cache kv_cache = self.kv_cache
block_tables = batch.block_tables_tensor block_tables = batch.block_tables_tensor
slots = batch.slots[batch.slot_indices] slots = batch.slots[batch.slot_indices]
input_lengths = batch.input_lengths_tensor input_lengths = batch.input_lengths_tensor
@@ -307,7 +303,7 @@ class VlmCausalLM(BaseFlashMistral):
input_ids = batch.input_ids input_ids = batch.input_ids
position_ids = batch.position_ids position_ids = batch.position_ids
cu_seqlen_prefill = batch.cu_seqlen_prefill cu_seqlen_prefill = batch.cu_seqlen_prefill
kv_cache = get_cache_manager().kv_cache kv_cache = self.kv_cache
block_tables = batch.block_tables_tensor block_tables = batch.block_tables_tensor
slots = batch.slots[batch.slot_indices] slots = batch.slots[batch.slot_indices]
input_lengths = batch.input_lengths_tensor input_lengths = batch.input_lengths_tensor