Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-22 15:32:08 +00:00)
Parent: cdd120ac02
Commit: 20df9234a9
@@ -26,7 +26,12 @@ incremental = true
 inherits = "release"
 debug = 1
 incremental = true
+panic = "abort"
+
+[profile.release-opt]
+inherits = "release"
+debug = 0
+incremental = false
 lto = "fat"
 opt-level = 3
 codegen-units = 1
-panic = "abort"
@@ -25,7 +25,7 @@ RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
     rm -f $PROTOC_ZIP

 COPY --from=planner /usr/src/recipe.json recipe.json
-RUN cargo chef cook --release --recipe-path recipe.json
+RUN cargo chef cook --profile release-opt --recipe-path recipe.json

 COPY Cargo.toml Cargo.toml
 COPY rust-toolchain.toml rust-toolchain.toml
@@ -33,7 +33,7 @@ COPY proto proto
 COPY benchmark benchmark
 COPY router router
 COPY launcher launcher
-RUN cargo build --release
+RUN cargo build --profile release-opt

 # Text Generation Inference base image for RoCm
 FROM rocm/dev-ubuntu-22.04:6.1.1_hip_update as base
@@ -193,11 +193,11 @@ RUN cd server && \
     pip install ".[accelerate, peft, outlines]" --no-cache-dir

 # Install benchmarker
-COPY --from=builder /usr/src/target/release/text-generation-benchmark /usr/local/bin/text-generation-benchmark
+COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
 # Install router
-COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bin/text-generation-router
+COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/local/bin/text-generation-router
 # Install launcher
-COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher
+COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher

 # AWS Sagemaker compatible image
 FROM base as sagemaker
@@ -24,7 +24,7 @@ RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
     rm -f $PROTOC_ZIP

 COPY --from=planner /usr/src/recipe.json recipe.json
-RUN cargo chef cook --release --recipe-path recipe.json
+RUN cargo chef cook --profile release-opt --recipe-path recipe.json

 COPY Cargo.toml Cargo.toml
 COPY rust-toolchain.toml rust-toolchain.toml
@@ -32,7 +32,7 @@ COPY proto proto
 COPY benchmark benchmark
 COPY router router
 COPY launcher launcher
-RUN cargo build --release
+RUN cargo build --profile release-opt


 # Text Generation Inference base image for Intel
@@ -78,11 +78,11 @@ ENV PATH=/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/intel/oneapi/mp
 ENV CCL_ZE_IPC_EXCHANGE=sockets

 # Install benchmarker
-COPY --from=builder /usr/src/target/release/text-generation-benchmark /usr/local/bin/text-generation-benchmark
+COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
 # Install router
-COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bin/text-generation-router
+COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/local/bin/text-generation-router
 # Install launcher
-COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher
+COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher

 # Final image
 FROM base
@@ -155,6 +155,8 @@ async fn prefill(
                 ignore_eos_token: true, // Will not stop even if a eos token is generated
             }),
             top_n_tokens: top_n_tokens.unwrap_or(0),
+            blocks: vec![],
+            slots: vec![],
         })
         .collect();

@@ -163,6 +165,7 @@ async fn prefill(
         requests,
         size: batch_size,
         max_tokens: batch_size * (sequence_length + decode_length),
+        max_blocks: 0,
     };

     // Run prefill
@@ -130,6 +130,10 @@ message Request {
     bool prefill_logprobs = 6;
     /// Return most likely n tokens
     uint32 top_n_tokens = 7;
+    /// Paged attention blocks
+    repeated uint32 blocks = 9;
+    /// Paged attention slots
+    repeated uint32 slots = 10;
 }

 message Batch {
@@ -141,6 +145,8 @@ message Batch {
     uint32 size = 3;
     /// Maximum number of tokens this batch will grow to
     uint32 max_tokens = 4;
+    /// Maximum number of Paged Attention blocks
+    uint32 max_blocks = 5;
 }

 message CachedBatch {
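Note: the new blocks and slots fields on Request, and max_blocks on Batch, are plain bookkeeping over a fixed paged-attention block size. A minimal illustrative sketch, not part of the commit, assuming the 16-token block size used elsewhere in this diff and made-up request sizes:

fn main() {
    // Assumed block size, matching the BLOCK_SIZE = 16 used elsewhere in this commit.
    let block_size: u32 = 16;
    // Made-up worst-case token counts for three requests in one batch.
    let request_tokens = [40_u32, 100, 17];

    // max_tokens: total tokens the batch may grow to (sum over requests).
    let max_tokens: u32 = request_tokens.iter().sum();
    // max_blocks: the largest per-request block count, i.e. ceil(tokens / block_size).
    let max_blocks: u32 = request_tokens
        .iter()
        .map(|t| (t + block_size - 1) / block_size)
        .max()
        .unwrap();

    assert_eq!(max_tokens, 157);
    assert_eq!(max_blocks, 7); // 100 tokens -> 7 blocks of 16
}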
@@ -153,6 +153,9 @@ impl Client {
                 }),
                 // We truncate the input on the server side to be sure that it has the correct size
                 truncate,
+                // Blocks and slots will be set on the server side if we use paged attention
+                blocks: vec![],
+                slots: vec![],
                 // Set sampling parameters to also take these ops into account in the max memory
                 parameters: Some(NextTokenChooserParameters {
                     temperature: 0.9,
@@ -187,7 +190,8 @@ impl Client {
             id: 0,
             size: requests.len() as u32,
             requests,
-            max_tokens: 0,
+            max_tokens: max_input_length,
+            max_blocks: 0,
         };

         let request = tonic::Request::new(WarmupRequest {
@@ -241,12 +241,16 @@ impl Health for ShardedClient {
             ignore_eos_token: false,
             }),
             top_n_tokens: 0,
+            // Block 0 is reserved for health checks
+            blocks: vec![0],
+            slots: (0..16).collect(),
         };
         let batch = Batch {
             id: u64::MAX,
             requests: vec![liveness_request],
             size: 1,
             max_tokens: 2,
+            max_blocks: 1,
         };
         self.clone().prefill(batch).await?;
         Ok(())
router/src/infer/v3/block_allocator.rs (new file, 136 lines)
@@ -0,0 +1,136 @@
+use std::cmp::min;
+use tokio::sync::{mpsc, oneshot};
+
+#[derive(Debug, Clone)]
+pub(crate) struct BlockAllocation {
+    pub blocks: Vec<u32>,
+    pub slots: Vec<u32>,
+    block_allocator: BlockAllocator,
+}
+
+impl Drop for BlockAllocation {
+    fn drop(&mut self) {
+        self.block_allocator.free(self.blocks.clone())
+    }
+}
+
+#[derive(Debug, Clone)]
+pub(crate) struct BlockAllocator {
+    /// Channel to communicate with the background task
+    block_allocator: mpsc::UnboundedSender<BlockAllocatorCommand>,
+}
+
+impl BlockAllocator {
+    pub(crate) fn new(
+        max_batch_total_tokens: u32,
+        block_size: u32,
+        window_size: Option<u32>,
+    ) -> Self {
+        // Create channel
+        let (sender, receiver) = mpsc::unbounded_channel();
+
+        // Launch background queue task
+        tokio::spawn(block_allocator_task(
+            max_batch_total_tokens / block_size,
+            block_size,
+            window_size,
+            receiver,
+        ));
+
+        Self {
+            block_allocator: sender,
+        }
+    }
+
+    pub(crate) async fn allocate(&self, tokens: u32) -> Option<BlockAllocation> {
+        let (response_sender, response_receiver) = oneshot::channel();
+        self.block_allocator
+            .send(BlockAllocatorCommand::Allocate {
+                tokens,
+                response_sender,
+            })
+            .unwrap();
+
+        response_receiver
+            .await
+            .unwrap()
+            .map(|(blocks, slots)| BlockAllocation {
+                blocks,
+                slots,
+                block_allocator: self.clone(),
+            })
+    }
+
+    pub(crate) fn free(&self, blocks: Vec<u32>) {
+        self.block_allocator
+            .send(BlockAllocatorCommand::Free { blocks })
+            .unwrap();
+    }
+}
+
+async fn block_allocator_task(
+    blocks: u32,
+    block_size: u32,
+    window_size: Option<u32>,
+    mut receiver: mpsc::UnboundedReceiver<BlockAllocatorCommand>,
+) {
+    // Block 0 is reserved for health checks
+    let mut free_blocks: Vec<u32> = (1..blocks).collect();
+    while let Some(cmd) = receiver.recv().await {
+        match cmd {
+            BlockAllocatorCommand::Free { blocks } => free_blocks.extend(blocks),
+            BlockAllocatorCommand::Allocate {
+                tokens,
+                response_sender,
+            } => {
+                // Apply window size
+                let (required_blocks, repeats) = {
+                    let (tokens, repeats) = match window_size {
+                        None => (tokens, 1),
+                        Some(window_size) => {
+                            let repeats = (tokens + window_size - 1) / window_size;
+                            let tokens = min(tokens, window_size);
+                            (tokens, repeats as usize)
+                        }
+                    };
+                    // Pad to a multiple of block size
+                    let required_blocks = (tokens + block_size - 1) / block_size;
+                    (required_blocks, repeats)
+                };
+
+                let tokens = tokens as usize;
+                let allocation = if required_blocks > free_blocks.len() as u32 {
+                    None
+                } else {
+                    let blocks =
+                        free_blocks.split_off(free_blocks.len() - required_blocks as usize);
+                    let mut slots = Vec::with_capacity(
+                        (required_blocks * block_size * repeats as u32) as usize,
+                    );
+
+                    'slots: for block_id in blocks.repeat(repeats).iter() {
+                        for s in (block_id * block_size)..((block_id + 1) * block_size) {
+                            slots.push(s);
+                            if slots.len() == tokens {
+                                break 'slots;
+                            }
+                        }
+                    }
+                    Some((blocks, slots))
+                };
+                response_sender.send(allocation).unwrap();
+            }
+        }
+    }
+}
+
+#[derive(Debug)]
+enum BlockAllocatorCommand {
+    Free {
+        blocks: Vec<u32>,
+    },
+    Allocate {
+        tokens: u32,
+        response_sender: oneshot::Sender<Option<(Vec<u32>, Vec<u32>)>>,
+    },
+}
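A minimal usage sketch for the allocator above (illustrative, not part of the commit): because the types are pub(crate), the natural place to exercise them is a test inside the same module, and a tokio runtime is assumed.

#[cfg(test)]
mod block_allocator_sketch {
    use super::*;

    #[tokio::test]
    async fn allocate_then_free() {
        // 4096 max batch tokens / block size 16 => 256 blocks, with block 0 reserved.
        let allocator = BlockAllocator::new(4096, 16, None);

        // ceil(150 / 16) = 10 blocks; the slot list is truncated to exactly 150 entries.
        let allocation = allocator.allocate(150).await.expect("enough free blocks");
        assert_eq!(allocation.blocks.len(), 10);
        assert_eq!(allocation.slots.len(), 150);

        // Dropping the allocation sends a Free command, returning the blocks to the pool.
        drop(allocation);
    }
}

Dropping a BlockAllocation is the only way blocks return to the free list, which is presumably why Entry gains a block_allocation: Option<BlockAllocation> field later in this commit.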
@@ -1,3 +1,4 @@
+mod block_allocator;
 mod queue;
 mod scheduler;

@@ -1,17 +1,20 @@
-use crate::infer::{InferError, InferStreamResponse};
+use crate::infer::v3::block_allocator::{BlockAllocation, BlockAllocator};
+use crate::infer::InferError;
+use crate::infer::InferStreamResponse;
 use crate::validation::{
     ValidGenerateRequest, ValidGrammar, ValidParameters, ValidStoppingParameters,
 };
 use nohash_hasher::{BuildNoHashHasher, IntMap};
-use std::cmp::min;
+use std::cmp::{max, min};
 use std::collections::VecDeque;
 use text_generation_client::v3::{
     Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
 };
-use text_generation_client::{ChunksToString, Input};
+use text_generation_client::ChunksToString;
+use text_generation_client::Input;
 use tokio::sync::{mpsc, oneshot};
 use tokio::time::Instant;
-use tracing::{info_span, instrument, Span};
+use tracing::{info_span, instrument, Instrument, Span};

 /// Queue entry
 #[derive(Debug)]
@@ -28,6 +31,8 @@ pub(crate) struct Entry {
     pub queue_time: Instant,
     /// Instant when this entry was added to a batch
     pub batch_time: Option<Instant>,
+    /// Block Allocation
+    pub block_allocation: Option<BlockAllocation>,
 }

 /// Request Queue
@@ -43,6 +48,7 @@ impl Queue {
         block_size: u32,
         window_size: Option<u32>,
         speculate: u32,
+        max_batch_total_tokens: u32,
     ) -> Self {
         // Create channel
         let (queue_sender, queue_receiver) = mpsc::unbounded_channel();
@@ -53,12 +59,14 @@ impl Queue {
             block_size,
             window_size,
             speculate,
+            max_batch_total_tokens,
             queue_receiver,
         ));

         Self { queue_sender }
     }

+    /// Append an entry to the queue
     #[instrument(skip_all)]
     pub(crate) fn append(&self, entry: Entry) {
         // Send append command to the background task managing the state
@@ -103,9 +111,16 @@ async fn queue_task(
     block_size: u32,
     window_size: Option<u32>,
     speculate: u32,
+    max_batch_total_tokens: u32,
     mut receiver: mpsc::UnboundedReceiver<QueueCommand>,
 ) {
-    let mut state = State::new(requires_padding, block_size, window_size, speculate);
+    let mut state = State::new(
+        requires_padding,
+        block_size,
+        window_size,
+        speculate,
+        max_batch_total_tokens,
+    );

     while let Some(cmd) = receiver.recv().await {
         match cmd {
@@ -120,12 +135,14 @@ async fn queue_task(
                 token_budget,
                 response_sender,
                 span,
-            } => span.in_scope(|| {
-                let next_batch =
-                    state.next_batch(min_size, max_size, prefill_token_budget, token_budget);
+            } => {
+                let next_batch = state
+                    .next_batch(min_size, max_size, prefill_token_budget, token_budget)
+                    .instrument(span)
+                    .await;
                 response_sender.send(next_batch).unwrap();
                 metrics::gauge!("tgi_queue_size", state.entries.len() as f64);
-            }),
+            }
         }
     }
 }
@@ -142,9 +159,6 @@ struct State {
     /// Id of the next batch
     next_batch_id: u64,

-    /// Whether the model is using padding
-    requires_padding: bool,
-
     /// Paged Attention block size
     block_size: u32,

@@ -153,6 +167,9 @@ struct State {

     /// Speculation amount
     speculate: u32,
+
+    /// Paged Attention Block Allocation
+    block_allocator: Option<BlockAllocator>,
 }

 impl State {
@@ -161,15 +178,19 @@ impl State {
         block_size: u32,
         window_size: Option<u32>,
         speculate: u32,
+        max_batch_total_tokens: u32,
     ) -> Self {
+        let block_allocator = (!requires_padding)
+            .then(|| BlockAllocator::new(max_batch_total_tokens, block_size, window_size));
+
         Self {
             entries: VecDeque::with_capacity(128),
             next_id: 0,
             next_batch_id: 0,
-            requires_padding,
             block_size,
             window_size,
             speculate,
+            block_allocator,
         }
     }

@@ -185,7 +206,7 @@ impl State {
     }

     // Get the next batch
-    fn next_batch(
+    async fn next_batch(
         &mut self,
         min_size: Option<usize>,
         max_size: Option<usize>,
@@ -220,9 +241,10 @@ impl State {
         let mut max_input_length = 0;
         let mut prefill_tokens: u32 = 0;
         let mut decode_tokens: u32 = 0;
+        let mut max_blocks = 0;

         // Pop entries starting from the front of the queue
-        while let Some((id, mut entry)) = self.entries.pop_front() {
+        'entry_loop: while let Some((id, mut entry)) = self.entries.pop_front() {
             // Filter entries where the response receiver was dropped (== entries where the request
             // was dropped by the client)
             if entry.response_tx.is_closed() {
@@ -231,43 +253,67 @@ impl State {
                 continue;
             }

-            if self.requires_padding {
-                // We pad to max input length in the Python shards
-                // We need to take these padding tokens into the equation
-                max_input_length = max_input_length.max(entry.request.input_length);
-                prefill_tokens = (batch_requests.len() + 1) as u32 * max_input_length
-            } else {
-                // pad to block size
-                prefill_tokens += ((entry.request.input_length + self.block_size - 1)
-                    / self.block_size)
-                    * self.block_size;
-            }
+            let block_allocation = match &self.block_allocator {
+                None => {
+                    // We pad to max input length in the Python shards
+                    // We need to take these padding tokens into the equation
+                    max_input_length = max_input_length.max(entry.request.input_length);
+                    prefill_tokens = (batch_requests.len() + 1) as u32 * max_input_length;

-            if self.requires_padding {
-                decode_tokens += entry.request.stopping_parameters.max_new_tokens;
-            } else {
-                let max_new_tokens = match self.window_size {
-                    None => entry.request.stopping_parameters.max_new_tokens,
-                    Some(window_size) => min(
-                        window_size.saturating_sub(entry.request.input_length),
-                        entry.request.stopping_parameters.max_new_tokens,
-                    ),
-                };
+                    decode_tokens += entry.request.stopping_parameters.max_new_tokens;
+                    let total_tokens = prefill_tokens + decode_tokens + self.speculate;

-                // pad to block size
-                decode_tokens +=
-                    ((max_new_tokens + self.block_size - 1) / self.block_size) * self.block_size;
-            }
+                    if prefill_tokens > prefill_token_budget || total_tokens > token_budget {
+                        // Entry is over budget
+                        // Add it back to the front
+                        tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
+                        self.entries.push_front((id, entry));
+                        break 'entry_loop;
+                    }
+                    None
+                }
+                Some(block_allocator) => {
+                    prefill_tokens += entry.request.input_length;
+                    let max_new_tokens = match self.window_size {
+                        None => entry.request.stopping_parameters.max_new_tokens,
+                        Some(window_size) => min(
+                            window_size.saturating_sub(entry.request.input_length),
+                            entry.request.stopping_parameters.max_new_tokens,
+                        ),
+                    };
+                    decode_tokens += max_new_tokens;

             if prefill_tokens > prefill_token_budget
                 || (prefill_tokens + decode_tokens + self.speculate) > token_budget
             {
                 // Entry is over budget
                 // Add it back to the front
                 tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
                 self.entries.push_front((id, entry));
                 break;
             }

+                    let tokens = entry.request.input_length
+                        + entry.request.stopping_parameters.max_new_tokens
+                        + self.speculate
+                        - 1;
+
+                    match block_allocator.allocate(tokens).await {
+                        None => {
+                            // Entry is over budget
+                            // Add it back to the front
+                            tracing::debug!("Over budget: not enough free blocks");
+                            self.entries.push_front((id, entry));
+                            break 'entry_loop;
+                        }
+                        Some(block_allocation) => {
+                            tracing::debug!("Allocation: {block_allocation:?}");
+                            max_blocks = max(max_blocks, block_allocation.blocks.len() as u32);
+                            Some(block_allocation)
+                        }
+                    }
+                }
+            };
+
             tracing::debug!("Accepting entry");
             // Create a new span to link the batch back to this entry
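Note: for the paged path above, the size passed to allocate is the worst case the sequence can reach. A small worked sketch of that arithmetic, with illustrative values only; the expressions mirror the hunk above:

fn main() {
    // Made-up request: 20 prompt tokens, up to 30 new tokens, 2 speculative tokens.
    let input_length: u32 = 20;
    let max_new_tokens: u32 = 30;
    let speculate: u32 = 2;

    // Mirrors "let tokens = input_length + max_new_tokens + speculate - 1" above.
    let tokens = input_length + max_new_tokens + speculate - 1;
    assert_eq!(tokens, 51);

    // The allocator then rounds up to whole 16-token blocks: ceil(51 / 16) = 4.
    let block_size: u32 = 16;
    let required_blocks = (tokens + block_size - 1) / block_size;
    assert_eq!(required_blocks, 4);
}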
@@ -278,13 +324,23 @@ impl State {
             // Update entry
             entry.temp_span = Some(entry_batch_span);

+            let (blocks, slots) = match &block_allocation {
+                None => (Vec::new(), Vec::new()),
+                Some(block_allocation) => (
+                    block_allocation.blocks.clone(),
+                    block_allocation.slots.clone(),
+                ),
+            };
+
+            entry.block_allocation = block_allocation;
+
             batch_requests.push(Request {
                 id,
                 prefill_logprobs: entry.request.decoder_input_details,
-                inputs: entry.request.inputs.chunks_to_string(),
                 input_chunks: Some(Input {
                     chunks: entry.request.inputs.clone(),
                 }),
+                inputs: entry.request.inputs.chunks_to_string(),
                 truncate: entry.request.truncate,
                 parameters: Some(NextTokenChooserParameters::from(
                     entry.request.parameters.clone(),
@@ -293,6 +349,8 @@ impl State {
                     entry.request.stopping_parameters.clone(),
                 )),
                 top_n_tokens: entry.request.top_n_tokens,
+                blocks,
+                slots,
             });
             // Set batch_time
             entry.batch_time = Some(Instant::now());
@@ -335,6 +393,7 @@ impl State {
             requests: batch_requests,
             size,
             max_tokens: (prefill_tokens + decode_tokens),
+            max_blocks,
         };
         // Increment batch id
         self.next_batch_id += 1;
@@ -438,13 +497,14 @@ mod tests {
             temp_span: None,
             queue_time: Instant::now(),
             batch_time: None,
+            block_allocation: None,
         };
         (entry, receiver_tx)
     }

-    #[test]
-    fn test_append() {
-        let mut state = State::new(false, 1, None, 0);
+    #[tokio::test]
+    async fn test_append() {
+        let mut state = State::new(false, 1, None, 0, 16);
         let (entry, _guard) = default_entry();

         assert_eq!(state.next_id, 0);
@@ -458,23 +518,23 @@ mod tests {
         assert_eq!(id, 0);
     }

-    #[test]
-    fn test_next_batch_empty() {
-        let mut state = State::new(false, 1, None, 0);
+    #[tokio::test]
+    async fn test_next_batch_empty() {
+        let mut state = State::new(false, 1, None, 0, 16);

-        assert!(state.next_batch(None, None, 1, 1).is_none());
-        assert!(state.next_batch(Some(1), None, 1, 1).is_none());
+        assert!(state.next_batch(None, None, 1, 1).await.is_none());
+        assert!(state.next_batch(Some(1), None, 1, 1).await.is_none());
     }

-    #[test]
-    fn test_next_batch_min_size() {
-        let mut state = State::new(false, 1, None, 0);
+    #[tokio::test]
+    async fn test_next_batch_min_size() {
+        let mut state = State::new(false, 1, None, 0, 16);
         let (entry1, _guard1) = default_entry();
         let (entry2, _guard2) = default_entry();
         state.append(entry1);
         state.append(entry2);

-        let (entries, batch, _) = state.next_batch(None, None, 2, 2).unwrap();
+        let (entries, batch, _) = state.next_batch(None, None, 2, 2).await.unwrap();
         assert_eq!(entries.len(), 2);
         assert!(entries.contains_key(&0));
         assert!(entries.contains_key(&1));
@@ -490,7 +550,7 @@ mod tests {
         let (entry3, _guard3) = default_entry();
         state.append(entry3);

-        assert!(state.next_batch(Some(2), None, 2, 2).is_none());
+        assert!(state.next_batch(Some(2), None, 2, 2).await.is_none());

         assert_eq!(state.next_id, 3);
         assert_eq!(state.entries.len(), 1);
@@ -498,15 +558,15 @@ mod tests {
         assert_eq!(id, 2);
     }

-    #[test]
-    fn test_next_batch_max_size() {
-        let mut state = State::new(false, 1, None, 0);
+    #[tokio::test]
+    async fn test_next_batch_max_size() {
+        let mut state = State::new(false, 1, None, 0, 16);
         let (entry1, _guard1) = default_entry();
         let (entry2, _guard2) = default_entry();
         state.append(entry1);
         state.append(entry2);

-        let (entries, batch, _) = state.next_batch(None, Some(1), 2, 2).unwrap();
+        let (entries, batch, _) = state.next_batch(None, Some(1), 2, 2).await.unwrap();
         assert_eq!(entries.len(), 1);
         assert!(entries.contains_key(&0));
         assert!(entries.get(&0).unwrap().batch_time.is_some());
@@ -518,15 +578,15 @@ mod tests {
         assert_eq!(state.next_batch_id, 1);
     }

-    #[test]
-    fn test_next_batch_token_budget() {
-        let mut state = State::new(false, 1, None, 0);
+    #[tokio::test]
+    async fn test_next_batch_token_budget() {
+        let mut state = State::new(false, 1, None, 0, 2);
         let (entry1, _guard1) = default_entry();
         let (entry2, _guard2) = default_entry();
         state.append(entry1);
         state.append(entry2);

-        let (entries, batch, _) = state.next_batch(None, None, 1, 1).unwrap();
+        let (entries, batch, _) = state.next_batch(None, None, 1, 1).await.unwrap();
         assert_eq!(entries.len(), 1);
         assert!(entries.contains_key(&0));
         assert_eq!(batch.id, 0);
@@ -539,7 +599,7 @@ mod tests {
         let (entry3, _guard3) = default_entry();
         state.append(entry3);

-        let (entries, batch, _) = state.next_batch(None, None, 3, 3).unwrap();
+        let (entries, batch, _) = state.next_batch(None, None, 3, 3).await.unwrap();
         assert_eq!(entries.len(), 2);
         assert!(entries.contains_key(&1));
         assert!(entries.contains_key(&2));
@@ -553,14 +613,14 @@ mod tests {

     #[tokio::test]
     async fn test_queue_append() {
-        let queue = Queue::new(false, 1, None, 0);
+        let queue = Queue::new(false, 1, None, 0, 16);
         let (entry, _guard) = default_entry();
         queue.append(entry);
     }

     #[tokio::test]
     async fn test_queue_next_batch_empty() {
-        let queue = Queue::new(false, 1, None, 0);
+        let queue = Queue::new(false, 1, None, 0, 16);

         assert!(queue.next_batch(None, None, 1, 1).await.is_none());
         assert!(queue.next_batch(Some(1), None, 1, 1).await.is_none());
@@ -568,7 +628,7 @@ mod tests {

     #[tokio::test]
     async fn test_queue_next_batch_min_size() {
-        let queue = Queue::new(false, 1, None, 0);
+        let queue = Queue::new(false, 1, None, 0, 16);
         let (entry1, _guard1) = default_entry();
         let (entry2, _guard2) = default_entry();
         queue.append(entry1);
@@ -601,7 +661,7 @@ mod tests {

     #[tokio::test]
     async fn test_queue_next_batch_max_size() {
-        let queue = Queue::new(false, 1, None, 0);
+        let queue = Queue::new(false, 1, None, 0, 16);
         let (entry1, _guard1) = default_entry();
         let (entry2, _guard2) = default_entry();
         queue.append(entry1);
@@ -617,7 +677,7 @@ mod tests {

     #[tokio::test]
     async fn test_queue_next_batch_token_budget() {
-        let queue = Queue::new(false, 1, None, 0);
+        let queue = Queue::new(false, 1, None, 0, 16);
         let (entry1, _guard1) = default_entry();
         let (entry2, _guard2) = default_entry();
         queue.append(entry1);
@@ -642,7 +702,7 @@ mod tests {

     #[tokio::test]
     async fn test_queue_next_batch_token_speculate() {
-        let queue = Queue::new(false, 1, None, 2);
+        let queue = Queue::new(false, 1, None, 2, 16);
         let (entry1, _guard1) = default_entry();
         let (entry2, _guard2) = default_entry();
         queue.append(entry1);
@@ -661,7 +721,7 @@ mod tests {

     #[tokio::test]
     async fn test_queue_next_batch_dropped_receiver() {
-        let queue = Queue::new(false, 1, None, 0);
+        let queue = Queue::new(false, 1, None, 0, 16);
         let (entry, _) = default_entry();
         queue.append(entry);

@@ -39,7 +39,13 @@ impl SchedulerV3 {
         speculate: u32,
         generation_health: Arc<AtomicBool>,
     ) -> Self {
-        let queue = Queue::new(requires_padding, 16, window_size, speculate);
+        let queue = Queue::new(
+            requires_padding,
+            16,
+            window_size,
+            speculate,
+            max_batch_total_tokens,
+        );
         let batching_task_notifier = Arc::new(Notify::new());

         // Spawn batching background task that contains all the inference logic
@@ -81,6 +87,7 @@ impl Scheduler for SchedulerV3 {
             temp_span: None,
             queue_time: Instant::now(),
             batch_time: None,
+            block_allocation: None,
         });

         // Notify the background task that we have a new entry in the queue that needs
@@ -1,140 +0,0 @@
-import math
-import torch
-
-from typing import Optional, List, Tuple
-from text_generation_server.utils.import_utils import SYSTEM
-
-BLOCK_SIZE: int = 16
-# Will be set in warmup
-CACHE_MANAGER: Optional["CacheManager"] = None
-
-
-class CacheManager:
-    def __init__(
-        self,
-        num_blocks: int,
-        num_layers: int,
-        num_heads: int,
-        head_size: int,
-        repeat_slots: bool,
-        dtype: torch.dtype,
-        device: torch.device,
-    ):
-        self.block_size = BLOCK_SIZE
-        self.num_blocks = num_blocks
-        self.repeat_slots = repeat_slots
-
-        element_size = torch.tensor([], dtype=dtype).element_size()
-        if SYSTEM == "xpu":
-            x = 1
-        else:
-            x = self.block_size // element_size
-
-        self.kv_cache = [
-            (
-                torch.empty(
-                    (num_blocks, num_heads, head_size // x, self.block_size, x),
-                    dtype=dtype,
-                    device=device,
-                ),
-                torch.empty(
-                    (num_blocks, num_heads, head_size, self.block_size),
-                    dtype=dtype,
-                    device=device,
-                ),
-            )
-            for _ in range(num_layers)
-        ]
-        self.free_block_mask = torch.ones(num_blocks, dtype=torch.int32, device="cpu")
-        self.slots = torch.arange(
-            0, num_blocks * self.block_size, dtype=torch.int64
-        ).view(num_blocks, self.block_size)
-
-    def allocate(
-        self,
-        needed_blocks_slots: List[Tuple[int, int]],
-        blocks: int,
-        max_blocks: int,
-        device: torch.device,
-    ):
-        # Get free blocks indices by finding values in mask that are not set to 0
-        free_block_indices = self.free_block_mask.nonzero()
-        if blocks > len(free_block_indices):
-            raise RuntimeError(
-                f"Out of available cache blocks: asked {blocks}, only {len(free_block_indices)} free blocks"
-            )
-
-        # Slice by the number of required blocks
-        block_indices = free_block_indices[:blocks]
-        block_indices = block_indices.flatten()
-
-        # Padded block tables
-        block_tables_tensor = torch.zeros(
-            (len(needed_blocks_slots), max_blocks), dtype=torch.int32
-        )
-
-        # Allocate paged attention blocks
-        cumulative_blocks = 0
-        slots = []
-        block_tables = []
-        for i, (needed_blocks, needed_slots) in enumerate(needed_blocks_slots):
-            # Get allocated blocks for this sequence
-            allocated_blocks = block_indices[
-                cumulative_blocks : cumulative_blocks + needed_blocks
-            ]
-            # Get slots for the allocated blocks
-            all_slots = self.slots[allocated_blocks].flatten()
-
-            # Repeat slots in the case of context sliding window
-            if needed_slots > len(all_slots) and self.repeat_slots:
-                repeats = math.ceil(needed_slots / len(all_slots))
-                all_slots = all_slots.repeat(repeats)
-
-            allocated_slots = all_slots[:needed_slots]
-
-            slots.append(allocated_slots)
-            block_tables.append(allocated_blocks.tolist())
-            block_tables_tensor[i, :needed_blocks] = allocated_blocks
-            cumulative_blocks += needed_blocks
-
-        block_tables = block_tables
-        block_tables_tensor = block_tables_tensor.to(device)
-        slots = torch.concat(slots).to(device)
-
-        # Allocate the required number of blocks by setting the mask to 0
-        self.free_block_mask[block_indices] = 0
-
-        return block_tables, block_tables_tensor, slots
-
-    def free(self, block_indices: Optional[List[int]]):
-        if block_indices is not None and block_indices:
-            # Reset mask
-            self.free_block_mask[block_indices] = 1
-
-
-def set_cache_manager(
-    num_blocks: int,
-    num_layers: int,
-    num_heads: int,
-    head_size: int,
-    repeat_slots: bool,
-    dtype: torch.dtype,
-    device: torch.device,
-) -> CacheManager:
-    global CACHE_MANAGER
-    if CACHE_MANAGER is not None:
-        del CACHE_MANAGER
-        torch.cuda.empty_cache()
-
-    CACHE_MANAGER = CacheManager(
-        num_blocks, num_layers, num_heads, head_size, repeat_slots, dtype, device
-    )
-    return CACHE_MANAGER
-
-
-def get_cache_manager() -> CacheManager:
-    global CACHE_MANAGER
-    if CACHE_MANAGER is None:
-        raise RuntimeError("cache manager was not initialized")
-
-    return CACHE_MANAGER
@@ -512,6 +512,7 @@ class FlashCohereForCausalLM(torch.nn.Module):
         slots: torch.Tensor,
         input_lengths: torch.Tensor,
         max_s: int,
+        prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         hidden_states = self.model(
@@ -834,6 +834,7 @@ class FlashDbrxForCausalLM(torch.nn.Module):
         slots: torch.Tensor,
         input_lengths: torch.Tensor,
         max_s: int,
+        prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         hidden_states = self.model(
@@ -458,6 +458,7 @@ class FlashGemmaForCausalLM(torch.nn.Module):
         slots: torch.Tensor,
         input_lengths: torch.Tensor,
         max_s: int,
+        prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         input_embeds = self.embed_tokens(input_ids)
@@ -388,6 +388,7 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
         slots: torch.Tensor,
         input_lengths: torch.Tensor,
         max_s: int,
+        prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         hidden_states = self.gpt_neox(
@@ -398,6 +398,7 @@ class FlashPhiForCausalLM(torch.nn.Module):
         slots: torch.Tensor,
         input_lengths: torch.Tensor,
         max_s: int,
+        prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         hidden_states = self.model(
@@ -670,6 +670,7 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
         slots: torch.Tensor,
         input_lengths: torch.Tensor,
         max_s: int,
+        prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         hidden_states = self.transformer(
@@ -482,6 +482,7 @@ class FlashSantacoderForCausalLM(nn.Module):
         slots: torch.Tensor,
         input_lengths: torch.Tensor,
         max_s: int,
+        prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         hidden_states = self.transformer(
@@ -25,11 +25,6 @@ from text_generation_server.models.types import (
     Generation,
     GeneratedText,
 )
-from text_generation_server.models.cache_manager import (
-    get_cache_manager,
-    set_cache_manager,
-    BLOCK_SIZE,
-)
 from text_generation_server.pb import generate_pb2
 from text_generation_server.models.globals import MEM_POOL, CUDA_GRAPHS
 import text_generation_server.models.globals as tgi_globals
@@ -44,6 +39,21 @@ from text_generation_server.utils.import_utils import (

 tracer = trace.get_tracer(__name__)

+BLOCK_SIZE: int = 16
+
+# Will be set in init
+SLIDING_WINDOW: Optional[int] = None
+
+
+def set_sliding_window(sliding_window: int):
+    global SLIDING_WINDOW
+    SLIDING_WINDOW = sliding_window
+
+
+def get_sliding_windows() -> int:
+    global SLIDING_WINDOW
+    return SLIDING_WINDOW
+

 @dataclass
 class FlashCausalLMBatch(Batch):
@@ -55,12 +65,15 @@ class FlashCausalLMBatch(Batch):
     # Decoder values
     input_ids: torch.Tensor
     position_ids: torch.Tensor
-    speculative_ids: torch.Tensor
+    speculative_ids: Optional[torch.Tensor]

     # Flash Attention values

     # tensor of length b containing the cumulative sequence lengths of the sequences in the batch, only used in prefill
     cu_seqlen_prefill: Optional[torch.Tensor]
+    # Prefill cache indices is used to slice into the kv tensor before caching it into the paged attention buffers
+    # as we only keep SLIDING_WINDOW values instead of the whole tensor
+    prefill_cache_indices: Optional[torch.Tensor]

     # Paged Attention values

@@ -69,16 +82,13 @@ class FlashCausalLMBatch(Batch):
     start_slots: torch.Tensor
     # tensor of indices of the currently used slots, length = \sum_{i=0}^{b} s_i in prefill, length = b in decode
     slot_indices: torch.Tensor
-    # List of tuple of ints representing the number of blocks and slots needed by each sequence
-    needed_blocks_slots: Optional[List[Tuple[int, int]]]

-    # Set in prefill by the CacheManager
     # list of length b of list of length s_i // block_size
-    block_tables: Optional[List[List[int]]]
+    block_tables: List[List[int]]
     # tensor of size [b, max_total_seqlen // block_size] holding the paged attention block tables for all sequences
-    block_tables_tensor: Optional[torch.Tensor]
+    block_tables_tensor: torch.Tensor
     # tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences
-    slots: Optional[torch.Tensor]
+    slots: torch.Tensor

     max_seqlen: int

@@ -104,7 +114,7 @@ class FlashCausalLMBatch(Batch):
     top_n_tokens_tensor: torch.Tensor

     # Number of blocks in this batch
-    blocks: int
+    num_blocks: int
     # Maximum number of blocks
     max_blocks: int

@@ -113,7 +123,7 @@ class FlashCausalLMBatch(Batch):
             id=self.batch_id,
             request_ids=[r.id for r in self.requests],
             size=len(self),
-            max_tokens=self.blocks * BLOCK_SIZE,
+            max_tokens=self.num_blocks * BLOCK_SIZE,
         )

     @classmethod
@@ -129,17 +139,6 @@ class FlashCausalLMBatch(Batch):
         )["input_ids"]
         return batch_tokenized_inputs

-    @classmethod
-    def from_pb(
-        cls,
-        pb: generate_pb2.Batch,
-        tokenizer: PreTrainedTokenizerBase,
-        dtype: torch.dtype,
-        device: torch.device,
-    ) -> "FlashCausalLMBatch":
-        batch_tokenized_inputs = cls.batch_tokenized_inputs(pb.requests, tokenizer)
-        return cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
-
     @classmethod
     def from_tokenized(
         cls,
@@ -149,12 +148,12 @@ class FlashCausalLMBatch(Batch):
         dtype: torch.dtype,
         device: torch.device,
     ) -> "FlashCausalLMBatch":
+        sliding_window = get_sliding_windows()
         position_ids = []
-        speculative_ids = []
         cu_seqlen_prefill = [0]
-        needed_blocks_slots = []
         start_slots = []
         slot_indices = []
+        prefill_cache_indices = []

         input_lengths = []
         prefix_offsets = []
@@ -177,11 +176,14 @@ class FlashCausalLMBatch(Batch):
         cumulative_max_length = 0
         prefill_out_cumulative_length = 0

-        blocks = 0
+        num_blocks = 0
         max_seqlen = 0
         max_length = 0
         max_blocks = 0

+        block_tables = []
+        slots = []
+
         # Parse batch
         for i, (r, tokenized_input) in enumerate(
             zip(pb.requests, batch_tokenized_inputs)
@@ -225,9 +227,25 @@ class FlashCausalLMBatch(Batch):
             speculative_length = get_speculate()
             speculative_length = 0 if speculative_length is None else speculative_length
             total_tokens = input_length + max_new_tokens - 1 + speculative_length
-            needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
-            blocks += needed_blocks
-            needed_blocks_slots.append((needed_blocks, total_tokens))
+
+            # blocks and slots can be empty (for example in warmup)
+            if not r.blocks:
+                needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
+                request_blocks = [
+                    b for b in range(num_blocks, num_blocks + needed_blocks)
+                ]
+                request_slots = [
+                    s
+                    for b in request_blocks
+                    for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)
+                ]
+            else:
+                request_blocks = r.blocks
+                request_slots = r.slots
+
+            block_tables.append(request_blocks)
+            slots.extend(request_slots[:total_tokens])
+            num_blocks += len(request_blocks)
             start_slots.append(cumulative_max_length)

             request_slot_indices = torch.arange(
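Note: the warmup fallback above rebuilds blocks and slots locally when the router did not send them. The same derivation, sketched in Rust for clarity (illustrative, not part of the commit; BLOCK_SIZE = 16 as above):

fn main() {
    let block_size: u32 = 16;
    let num_blocks: u32 = 5; // blocks already handed to earlier requests in the batch
    let total_tokens: u32 = 40; // input_length + max_new_tokens - 1 + speculative_length

    // needed_blocks = math.ceil(total_tokens / BLOCK_SIZE) in the Python above.
    let needed_blocks = (total_tokens + block_size - 1) / block_size;
    let request_blocks: Vec<u32> = (num_blocks..num_blocks + needed_blocks).collect();
    let request_slots: Vec<u32> = request_blocks
        .iter()
        .copied()
        .flat_map(|b| (b * block_size)..((b + 1) * block_size))
        .collect();

    assert_eq!(request_blocks, vec![5, 6, 7]);
    assert_eq!(request_slots.len(), 48); // later truncated to total_tokens (40)
}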
@ -237,6 +255,15 @@ class FlashCausalLMBatch(Batch):
|
|||||||
)
|
)
|
||||||
slot_indices.append(request_slot_indices)
|
slot_indices.append(request_slot_indices)
|
||||||
|
|
||||||
|
# Create tensor to slice into the kv tensor in prefill
|
||||||
|
if sliding_window is not None:
|
||||||
|
request_prefill_cache_indices = torch.arange(
|
||||||
|
cumulative_length + max(0, input_length - sliding_window),
|
||||||
|
cumulative_length + input_length,
|
||||||
|
dtype=torch.int64,
|
||||||
|
)
|
||||||
|
prefill_cache_indices.append(request_prefill_cache_indices)
|
||||||
|
|
||||||
all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs
|
all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs
|
||||||
no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs
|
no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs
|
||||||
|
|
||||||
@ -261,7 +288,7 @@ class FlashCausalLMBatch(Batch):
|
|||||||
cumulative_length += input_length
|
cumulative_length += input_length
|
||||||
cumulative_max_length += total_tokens
|
cumulative_max_length += total_tokens
|
||||||
max_seqlen = max(max_seqlen, input_length)
|
max_seqlen = max(max_seqlen, input_length)
|
||||||
max_blocks = max(max_blocks, needed_blocks)
|
max_blocks = max(max_blocks, len(request_blocks))
|
||||||
max_length = max(
|
max_length = max(
|
||||||
max_length, input_length + max_new_tokens + speculative_length
|
max_length, input_length + max_new_tokens + speculative_length
|
||||||
)
|
)
|
||||||
@@ -287,16 +314,23 @@ class FlashCausalLMBatch(Batch):
             input_ids = np.concatenate(all_input_ids, dtype=np.int64)
             position_ids = torch.cat(position_ids)
             slot_indices = torch.cat(slot_indices)
+            if sliding_window is not None:
+                prefill_cache_indices = torch.cat(prefill_cache_indices)
         else:
             input_ids = all_input_ids[0]
             position_ids = position_ids[0]
             slot_indices = slot_indices[0]
+            if sliding_window is not None:
+                prefill_cache_indices = prefill_cache_indices[0]
 
         cu_seqlen_prefill = torch.tensor(
             cu_seqlen_prefill, device=device, dtype=torch.int32
         )
         position_ids = position_ids.to(device)
         slot_indices = slot_indices.to(device)
+        prefill_cache_indices = (
+            prefill_cache_indices.to(device) if sliding_window is not None else None
+        )
         input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
         input_lengths_tensor = torch.tensor(
             input_lengths, dtype=torch.int32, device=device
@@ -319,6 +353,14 @@ class FlashCausalLMBatch(Batch):
             top_n_tokens, device=device, dtype=torch.int64
         )
 
+        slots = torch.tensor(slots, dtype=torch.int64, device=device)
+        block_tables_tensor = torch.zeros(
+            (len(block_tables), max_blocks), dtype=torch.int32, device="cpu"
+        )
+        for i, request_blocks in enumerate(block_tables):
+            block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks)
+        block_tables_tensor = block_tables_tensor.to(device)
+
         return cls(
             batch_id=pb.id,
             requests=pb.requests,
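The block-table padding added just above, as a standalone torch sketch (illustrative only; the real code also moves the tensor to the model's device): shorter rows are left zero-padded up to max_blocks.

import torch

def pad_block_tables(block_tables):
    # block_tables: list of per-request block id lists, possibly of different lengths.
    max_blocks = max(len(b) for b in block_tables)
    out = torch.zeros((len(block_tables), max_blocks), dtype=torch.int32)
    for i, request_blocks in enumerate(block_tables):
        out[i, : len(request_blocks)] = torch.tensor(request_blocks)
    return out

# Two requests holding 3 and 1 blocks respectively.
print(pad_block_tables([[7, 8, 9], [10]]))
# tensor([[ 7,  8,  9],
#         [10,  0,  0]], dtype=torch.int32)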
@@ -326,12 +368,12 @@ class FlashCausalLMBatch(Batch):
             input_ids=input_ids,
             position_ids=position_ids,
             cu_seqlen_prefill=cu_seqlen_prefill,
+            prefill_cache_indices=prefill_cache_indices,
             start_slots=start_slots,
             slot_indices=slot_indices,
-            needed_blocks_slots=needed_blocks_slots,
-            block_tables=None,
-            block_tables_tensor=None,
-            slots=None,
+            block_tables=block_tables,
+            block_tables_tensor=block_tables_tensor,
+            slots=slots,
             max_seqlen=max_seqlen,
             prefill_head_indices=prefill_head_indices,
             prefill_next_token_indices=prefill_next_token_indices,
@@ -346,11 +388,22 @@ class FlashCausalLMBatch(Batch):
             stopping_criterias=stopping_criterias,
             top_n_tokens=top_n_tokens,
             top_n_tokens_tensor=top_n_tokens_tensor,
-            blocks=blocks,
+            num_blocks=num_blocks,
             max_blocks=max_blocks,
             speculative_ids=None,
         )
 
+    @classmethod
+    def from_pb(
+        cls,
+        pb: generate_pb2.Batch,
+        tokenizer: PreTrainedTokenizerBase,
+        dtype: torch.dtype,
+        device: torch.device,
+    ) -> "FlashCausalLMBatch":
+        batch_tokenized_inputs = cls.batch_tokenized_inputs(pb.requests, tokenizer)
+        return cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
+
     @tracer.start_as_current_span("filter")
     def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
         if len(request_ids) == 0:
@@ -388,7 +441,7 @@ class FlashCausalLMBatch(Batch):
         stopping_criterias = []
         top_n_tokens = []
 
-        blocks = 0
+        num_blocks = 0
         max_blocks = 0
         # Cumulative length
         cumulative_max_length = 0
@@ -420,7 +473,7 @@ class FlashCausalLMBatch(Batch):
             )
 
             request_block_table = self.block_tables[idx]
-            blocks += len(request_block_table)
+            num_blocks += len(request_block_table)
             block_tables.append(request_block_table)
             start_slots.append(cumulative_max_length)
 
@@ -439,17 +492,6 @@ class FlashCausalLMBatch(Batch):
 
             max_blocks = max(max_blocks, len(request_block_table))
 
-        block_indices_to_free = []
-        # Iterate on all requests
-        for i, r in enumerate(self.requests):
-            # Filter requests that are not part of the new batch
-            if r.id not in requests_idx_mapping.keys():
-                block_indices_to_free.extend(self.block_tables[i])
-        # Free blocks
-        get_cache_manager().free(block_indices_to_free)
-        # Needed to avoid dropping blocks when the batches will go out of scope
-        self.block_tables = None
-
         # Index into tensors
         input_ids = self.input_ids[indices]
         position_ids = self.position_ids[indices]
@@ -475,9 +517,9 @@ class FlashCausalLMBatch(Batch):
             input_ids=input_ids,
             position_ids=position_ids,
             cu_seqlen_prefill=None,
+            prefill_cache_indices=None,
             start_slots=start_slots,
             slot_indices=slot_indices,
-            needed_blocks_slots=None,
             block_tables=block_tables,
             block_tables_tensor=block_tables_tensor,
             slots=slots,
@@ -495,7 +537,7 @@ class FlashCausalLMBatch(Batch):
             stopping_criterias=stopping_criterias,
             top_n_tokens=top_n_tokens,
             top_n_tokens_tensor=top_n_tokens_tensor,
-            blocks=blocks,
+            num_blocks=num_blocks,
             max_blocks=max_blocks,
             speculative_ids=speculative_ids,
         )
@@ -507,7 +549,7 @@ class FlashCausalLMBatch(Batch):
         requests = []
         requests_idx_mapping = {}
 
-        blocks = 0
+        num_blocks = 0
         total_batch_size = 0
         total_slots = 0
         max_blocks = 0
@@ -516,7 +558,7 @@ class FlashCausalLMBatch(Batch):
         for b in batches:
             total_batch_size += len(b)
             total_slots += len(b.slots)
-            blocks += b.blocks
+            num_blocks += b.num_blocks
             speculative_length = (
                 b.speculative_ids.shape[1] if b.speculative_ids is not None else 0
             )
@@ -635,11 +677,6 @@ class FlashCausalLMBatch(Batch):
             else None
         )
 
-        # Needed to avoid dropping blocks when the batches will go out of scope
-        for b in batches:
-            b.block_tables = None
-            del b
-
         return cls(
             batch_id=batches[0].batch_id,
             requests=requests,
@@ -647,9 +684,9 @@ class FlashCausalLMBatch(Batch):
             input_ids=input_ids,
             position_ids=position_ids,
             cu_seqlen_prefill=None,
+            prefill_cache_indices=None,
             start_slots=start_slots,
             slot_indices=slot_indices,
-            needed_blocks_slots=None,
             block_tables=block_tables,
             block_tables_tensor=block_tables_tensor,
             slots=slots,
@@ -667,18 +704,11 @@ class FlashCausalLMBatch(Batch):
             stopping_criterias=stopping_criterias,
             top_n_tokens=top_n_tokens,
             top_n_tokens_tensor=top_n_tokens_tensor,
-            blocks=blocks,
+            num_blocks=num_blocks,
             max_blocks=max_blocks,
             speculative_ids=speculative_ids,
         )
 
-    def __del__(self):
-        if self.block_tables is not None and self.block_tables:
-            # Free blocks
-            get_cache_manager().free(
-                list(itertools.chain.from_iterable(self.block_tables))
-            )
-
     def __len__(self):
         return len(self.requests)
 
@@ -702,6 +732,7 @@ class FlashCausalLM(Model):
         self.head_size = head_size
 
         self.cuda_graphs = {}
+        self.kv_cache = []
 
         super(FlashCausalLM, self).__init__(
             model=model,
@@ -718,6 +749,43 @@ class FlashCausalLM(Model):
     def batch_type(self) -> Type[FlashCausalLMBatch]:
         return FlashCausalLMBatch
 
+    def max_past(self) -> int:
+        return getattr(self.model, "max_past", None)
+
+    def init_kv_cache(
+        self,
+        num_blocks: int,
+        num_layers: int,
+        num_heads: int,
+        head_size: int,
+        dtype: torch.dtype,
+        device: torch.device,
+    ):
+        self.kv_cache = []
+        empty_cache()
+
+        element_size = torch.tensor([], dtype=dtype).element_size()
+        if SYSTEM == "xpu":
+            x = 1
+        else:
+            x = BLOCK_SIZE // element_size
+
+        self.kv_cache = [
+            (
+                torch.empty(
+                    (num_blocks, num_heads, head_size // x, BLOCK_SIZE, x),
+                    dtype=dtype,
+                    device=device,
+                ),
+                torch.empty(
+                    (num_blocks, num_heads, head_size, BLOCK_SIZE),
+                    dtype=dtype,
+                    device=device,
+                ),
+            )
+            for _ in range(num_layers)
+        ]
+
     def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
         input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
         position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
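A rough sizing sketch for the new init_kv_cache above. The shapes mirror the diff; the helper name, the fp16 default and the example numbers are assumptions for illustration, not measurements.

import math
import torch

def kv_cache_bytes(num_blocks: int, num_layers: int, num_heads: int,
                   head_size: int, dtype=torch.float16, block_size: int = 16) -> int:
    element_size = torch.tensor([], dtype=dtype).element_size()
    x = block_size // element_size  # the key tensor is tiled along head_size by this factor
    key_shape = (num_blocks, num_heads, head_size // x, block_size, x)
    value_shape = (num_blocks, num_heads, head_size, block_size)
    per_layer = (math.prod(key_shape) + math.prod(value_shape)) * element_size
    return per_layer * num_layers

# Hypothetical 7B-class config: 32 layers, 32 KV heads of size 128, fp16, 1000 blocks.
print(kv_cache_bytes(num_blocks=1000, num_layers=32, num_heads=32, head_size=128) / 2**20, "MiB")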
@@ -728,12 +796,11 @@ class FlashCausalLM(Model):
             .repeat(bs)
             .reshape((bs, max_bt))
         )
-        kv_cache = get_cache_manager().kv_cache
 
         self.cuda_graphs[bs] = {
             "input_ids": input_ids,
             "position_ids": position_ids,
-            "kv_cache": kv_cache,
+            "kv_cache": self.kv_cache,
             "block_tables": block_tables,
             "slots": slots,
             "input_lengths": input_lengths,
@@ -747,11 +814,12 @@ class FlashCausalLM(Model):
             input_ids=input_ids,
             position_ids=position_ids,
             cu_seqlen_prefill=None,
-            kv_cache=kv_cache,
+            kv_cache=self.kv_cache,
             block_tables=block_tables,
             slots=slots,
             input_lengths=input_lengths,
             max_s=max_s,
+            prefill_cache_indices=None,
             lm_head_indices=None,
         )
         torch.cuda.synchronize()
@@ -761,11 +829,12 @@ class FlashCausalLM(Model):
                 input_ids=input_ids,
                 position_ids=position_ids,
                 cu_seqlen_prefill=None,
-                kv_cache=kv_cache,
+                kv_cache=self.kv_cache,
                 block_tables=block_tables,
                 slots=slots,
                 input_lengths=input_lengths,
                 max_s=max_s,
+                prefill_cache_indices=None,
                 lm_head_indices=None,
             )
             self.cuda_graphs[bs]["logits"] = logits
@@ -777,17 +846,16 @@ class FlashCausalLM(Model):
         empty_cache()
 
         try:
-            cache_manager = set_cache_manager(
-                batch.blocks,
+            self.init_kv_cache(
+                batch.num_blocks,
                 self.num_layers,
                 self.num_kv_heads,
                 self.head_size,
-                self.sliding_window is not None,
                 self.dtype,
                 self.device,
             )
             max_bt = batch.max_blocks
-            max_s = max_bt * get_cache_manager().block_size
+            max_s = max_bt * BLOCK_SIZE
 
             if SYSTEM == "rocm" and os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False):
                 torch.cuda.tunable.tuning_enable(False)
@@ -811,19 +879,17 @@ class FlashCausalLM(Model):
         num_blocks = (
             # Leave 5% for some wiggle room
             int((free_memory * 0.95) // total_cache_size)
-            # Add batch.blocks as we allocated it above, so it is included in the peak memory.
-            + cache_manager.num_blocks
+            # Add batch.num_blocks as we allocated it above, so it is included in the peak memory.
+            + batch.num_blocks
         )
 
         del batch
-        del cache_manager
 
-        set_cache_manager(
+        self.init_kv_cache(
             num_blocks,
             self.num_layers,
             self.num_kv_heads,
             self.head_size,
-            self.sliding_window is not None,
             self.dtype,
             self.device,
         )
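The cache sizing above boils down to simple arithmetic; here it is as a standalone sketch with made-up numbers. The names are illustrative, and cache_block_size_bytes stands in for the per-block byte cost across all layers that warmup computes earlier.

def total_num_blocks(free_memory: int, cache_block_size_bytes: int, warmup_blocks: int) -> int:
    # Keep 5% of the currently free memory as headroom, convert the rest to
    # whole cache blocks, and re-add the blocks allocated for the warmup batch.
    return int((free_memory * 0.95) // cache_block_size_bytes) + warmup_blocks

# E.g. 20 GiB free, 256 KiB per block across all layers, 264 warmup blocks.
print(total_num_blocks(20 * 2**30, 256 * 2**10, 264))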
@@ -889,7 +955,6 @@ class FlashCausalLM(Model):
         input_ids = torch.zeros(seqlen, dtype=torch.int64, device=self.device)
         position_ids = torch.zeros(seqlen, dtype=torch.int32, device=self.device)
         slots = torch.arange(seqlen, dtype=torch.int64, device=self.device)
-        kv_cache = get_cache_manager().kv_cache
 
         # Dummy value, some models (starcoder2) don't accept `None`.
         input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
@@ -901,12 +966,13 @@ class FlashCausalLM(Model):
             cu_seqlen_prefill=torch.tensor(
                 [0, seqlen], device=self.device, dtype=torch.int32
             ),
-            kv_cache=get_cache_manager().kv_cache,
+            kv_cache=self.kv_cache,
             block_tables=None,
             input_lengths=input_lengths,
             slots=slots,
             max_s=seqlen,
             lm_head_indices=None,
+            prefill_cache_indices=None,
         )
 
     def forward(
@@ -917,7 +983,7 @@ class FlashCausalLM(Model):
             input_ids = batch.input_ids
             position_ids = batch.position_ids
             cu_seqlen_prefill = batch.cu_seqlen_prefill
-            kv_cache = get_cache_manager().kv_cache
+            kv_cache = self.kv_cache
             block_tables = batch.block_tables_tensor
             slots = batch.slots[batch.slot_indices]
             input_lengths = batch.input_lengths_tensor
@@ -956,13 +1022,19 @@ class FlashCausalLM(Model):
             input_ids = batch.input_ids
             position_ids = batch.position_ids
             cu_seqlen_prefill = batch.cu_seqlen_prefill
-            kv_cache = get_cache_manager().kv_cache
+            kv_cache = self.kv_cache
             block_tables = batch.block_tables_tensor
             slots = batch.slots[batch.slot_indices]
             input_lengths = batch.input_lengths_tensor
             max_s = batch.max_seqlen
             lm_head_indices = batch.prefill_head_indices
 
+        if cu_seqlen_prefill is None and self.max_past() is not None:
+            # In decode, not prefill, we're actually overwriting the KV-cache
+            # in a circular buffer mode.
+            # This makes sure the max_s for the decode pass is correct.
+            max_s = min(self.max_past(), max_s)
+
         bs = input_ids.shape[0]
         sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
         if sorted_padded_bs:
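The decode path above picks a captured CUDA graph by batch-size bucket. A minimal sketch of that selection rule, standalone, with a plain dict standing in for self.cuda_graphs and illustrative bucket sizes:

from typing import Optional

def pick_cuda_graph_bucket(captured: dict, bs: int) -> Optional[object]:
    # Smallest captured batch size that can hold the live batch, else None.
    sorted_padded_bs = sorted(k for k in captured.keys() if k >= bs)
    return captured[sorted_padded_bs[0]] if sorted_padded_bs else None

graphs = {1: "g1", 2: "g2", 4: "g4", 8: "g8", 16: "g16", 32: "g32"}
assert pick_cuda_graph_bucket(graphs, 3) == "g4"   # padded up to the next bucket
assert pick_cuda_graph_bucket(graphs, 40) is None  # too big: run the model eagerly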
@@ -972,7 +1044,7 @@ class FlashCausalLM(Model):
             cuda_graph = None
 
         if cu_seqlen_prefill is not None or cuda_graph is None:
-            return self.model.forward(
+            logits, speculative_logits = self.model.forward(
                 input_ids=input_ids,
                 position_ids=position_ids,
                 cu_seqlen_prefill=cu_seqlen_prefill,
@@ -981,8 +1053,12 @@ class FlashCausalLM(Model):
                 slots=slots,
                 input_lengths=input_lengths,
                 max_s=max_s,
+                prefill_cache_indices=batch.prefill_cache_indices,
                 lm_head_indices=lm_head_indices,
             )
+            if batch.prefill_cache_indices is not None:
+                batch.prefill_cache_indices = None
+            return logits, speculative_logits
 
         # Copy inputs to the static inputs of the cuda graph
         # Static inputs are potentially padded
@@ -1015,24 +1091,7 @@ class FlashCausalLM(Model):
         prefill = batch.cu_seqlen_prefill is not None
         prefill_logprobs = batch.prefill_next_token_indices is not None
 
-        if batch.needed_blocks_slots:
-            # Allocate blocks to this batch
-            block_tables, block_tables_tensor, slots = get_cache_manager().allocate(
-                batch.needed_blocks_slots,
-                batch.blocks,
-                batch.max_blocks,
-                batch.input_ids.device,
-            )
-            batch.needed_blocks_slots = None
-            batch.block_tables = block_tables
-            batch.block_tables_tensor = block_tables_tensor
-            batch.slots = slots
-
-        try:
-            out, speculative_logits = self.forward(batch)
-        except Exception as e:
-            del batch
-            raise e
+        out, speculative_logits = self.forward(batch)
 
         if prefill:
             next_token_logits = (
@@ -1327,7 +1386,6 @@ class FlashCausalLM(Model):
                 batch.all_input_ids[i] = all_input_ids
 
         if stopped:
-            del batch
             # No need to return a batch if we know that all requests stopped
             forward_ns = start_decode - start
             decode_ns = time.time_ns() - start_decode
@@ -1,308 +1,24 @@
-import math
 import torch
 import torch.distributed
 
-import numpy as np
-
-from dataclasses import dataclass
 from opentelemetry import trace
-from transformers import PreTrainedTokenizerBase, AutoTokenizer, AutoConfig
-from typing import Optional, Tuple, Type
+from transformers import AutoTokenizer, AutoConfig
+from typing import Optional, Tuple
 
-from text_generation_server.pb import generate_pb2
 from text_generation_server.models import FlashCausalLM
-from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch, BLOCK_SIZE
-from text_generation_server.models.cache_manager import (
-    get_cache_manager,
-)
+from text_generation_server.models.flash_causal_lm import set_sliding_window
 from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
     FlashMistralForCausalLM,
     MistralConfig,
 )
-from text_generation_server.utils.speculate import get_speculate
 from text_generation_server.utils import (
     initialize_torch_distributed,
     weight_files,
     Weights,
-    HeterogeneousNextTokenChooser,
-    StoppingCriteria,
 )
 
-tracer = trace.get_tracer(__name__)
-
-# Will be set in init
-SLIDING_WINDOW: Optional[int] = None
-SLIDING_WINDOW_BLOCKS: Optional[int] = None
 from text_generation_server.utils.import_utils import SYSTEM
 
-MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
-
-
-def set_sliding_window(sliding_window: int, sliding_window_blocks: int):
-    global SLIDING_WINDOW
-    global SLIDING_WINDOW_BLOCKS
-    SLIDING_WINDOW = sliding_window
-    SLIDING_WINDOW_BLOCKS = sliding_window_blocks
-
-
-def get_sliding_windows() -> Tuple[int, int]:
-    global SLIDING_WINDOW
-    global SLIDING_WINDOW_BLOCKS
-    return SLIDING_WINDOW, SLIDING_WINDOW_BLOCKS
-
-
-# Adds windowing logic to FlashCausalLMBatch
-@dataclass
-class FlashMistralBatch(FlashCausalLMBatch):
-    # Prefill cache indices is used to slice into the kv tensor before caching it into the paged attention buffers
-    # as we only keep SLIDING_WINDOW values instead of the whole tensor
-    prefill_cache_indices: Optional[torch.Tensor] = None
-
-    @classmethod
-    def from_pb(
-        cls,
-        pb: generate_pb2.Batch,
-        tokenizer: PreTrainedTokenizerBase,
-        dtype: torch.dtype,
-        device: torch.device,
-    ) -> "FlashCausalLMBatch":
-        batch_tokenized_inputs = cls.batch_tokenized_inputs(pb.requests, tokenizer)
-        return cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
-
-    @classmethod
-    def from_tokenized(
-        cls,
-        pb: generate_pb2.Batch,
-        tokenizer: PreTrainedTokenizerBase,
-        batch_tokenized_inputs,
-        dtype: torch.dtype,
-        device: torch.device,
-    ) -> "FlashCausalLMBatch":
-        sliding_window, sliding_window_blocks = get_sliding_windows()
-
-        position_ids = []
-        cu_seqlen_prefill = [0]
-        needed_blocks_slots = []
-        start_slots = []
-        slot_indices = []
-        prefill_cache_indices = []
-
-        input_lengths = []
-        prefix_offsets = []
-        read_offsets = []
-        all_input_ids = []
-        requests_idx_mapping = {}
-
-        all_prefill_logprobs = True
-        no_prefill_logprobs = True
-        prefill_head_indices = []
-        prefill_next_token_indices = []
-        prefill_cu_outlens = [0]
-
-        next_token_chooser_parameters = []
-        stopping_criterias = []
-        top_n_tokens = []
-
-        # Cumulative length
-        cumulative_length = 0
-        cumulative_max_length = 0
-        prefill_out_cumulative_length = 0
-
-        blocks = 0
-        max_seqlen = 0
-        max_length = 0
-        max_blocks = 0
-
-        # Parse batch
-        for i, (r, tokenized_input) in enumerate(
-            zip(pb.requests, batch_tokenized_inputs)
-        ):
-            # request id -> idx in list mapping
-            requests_idx_mapping[r.id] = i
-
-            tokenized_input = tokenized_input[-r.truncate :]
-            if (
-                tokenized_input[0] == tokenizer.bos_token_id
-                and tokenized_input[1] == tokenizer.bos_token_id
-            ):
-                tokenized_input = tokenized_input[1:]
-
-            input_length = len(tokenized_input)
-            input_lengths.append(input_length)
-
-            prefix_offsets.append(input_length - 5)
-            read_offsets.append(input_length)
-
-            all_input_ids.append(tokenized_input)
-
-            # Position ids
-            request_position_ids = torch.arange(0, input_length, dtype=torch.int32)
-            position_ids.append(request_position_ids)
-
-            # Add cumulative lengths of all previous inputs
-            cu_seqlen_prefill.append(cumulative_length + input_length)
-
-            next_token_chooser_parameters.append(r.parameters)
-
-            stopping_criteria = StoppingCriteria.from_pb(
-                r.stopping_parameters, tokenizer
-            )
-            max_new_tokens = stopping_criteria.max_new_tokens
-            stopping_criterias.append(stopping_criteria)
-            top_n_tokens.append(r.top_n_tokens)
-
-            # Paged attention
-            # Remove one as the first token des not have a past
-            speculative_length = get_speculate()
-            total_tokens = input_length + max_new_tokens - 1 + speculative_length
-
-            # Needed blocks can not go over SLIDING_WINDOW_BLOCKS
-            needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
-            if sliding_window_blocks is not None:
-                needed_blocks = min(needed_blocks, sliding_window_blocks)
-            blocks += needed_blocks
-
-            needed_blocks_slots.append((needed_blocks, total_tokens))
-            start_slots.append(cumulative_max_length)
-
-            request_slot_indices = torch.arange(
-                cumulative_max_length,
-                cumulative_max_length + input_length,
-                dtype=torch.int64,
-            )
-            slot_indices.append(request_slot_indices)
-
-            # Create tensor to slice into the kv tensor in prefill
-            if sliding_window is not None:
-                request_prefill_cache_indices = torch.arange(
-                    cumulative_length + max(0, input_length - sliding_window),
-                    cumulative_length + input_length,
-                    dtype=torch.int64,
-                )
-                prefill_cache_indices.append(request_prefill_cache_indices)
-
-            all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs
-            no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs
-
-            if r.prefill_logprobs:
-                prefill_head_indices.append(request_position_ids + cumulative_length)
-                prefill_next_token_indices.append(
-                    prefill_out_cumulative_length + input_length - 1
-                )
-                prefill_cu_outlens.append(prefill_out_cumulative_length + input_length)
-                prefill_out_cumulative_length += input_length
-            else:
-                prefill_head_indices.append(
-                    torch.tensor(
-                        [cumulative_length + input_length - 1], dtype=torch.int32
-                    )
-                )
-                prefill_next_token_indices.append(prefill_out_cumulative_length)
-                prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
-                prefill_out_cumulative_length += 1
-
-            # Update
-            cumulative_length += input_length
-            cumulative_max_length += total_tokens
-            max_seqlen = max(max_seqlen, input_length)
-            max_blocks = max(max_blocks, needed_blocks)
-            max_length = max(
-                max_length, input_length + max_new_tokens + speculative_length
-            )
-
-        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
-            next_token_chooser_parameters, dtype, device, tokenizer
-        )
-        start_slots = torch.tensor(start_slots, dtype=torch.int64)
-
-        # Padded all_input_ids_tensor
-        all_input_ids_tensor = np.zeros(
-            (len(all_input_ids), max_length), dtype=np.int64
-        )
-        for i, input_ids in enumerate(all_input_ids):
-            all_input_ids_tensor[i, : len(input_ids)] = input_ids
-
-        # Create tensors on device
-        all_input_ids_tensor = torch.tensor(
-            all_input_ids_tensor, dtype=torch.int64, device=device
-        )
-
-        if len(pb.requests) > 1:
-            input_ids = np.concatenate(all_input_ids, dtype=np.int64)
-            position_ids = torch.cat(position_ids)
-            slot_indices = torch.cat(slot_indices)
-            if sliding_window is not None:
-                prefill_cache_indices = torch.cat(prefill_cache_indices)
-        else:
-            input_ids = all_input_ids[0]
-            position_ids = position_ids[0]
-            slot_indices = slot_indices[0]
-            if sliding_window is not None:
-                prefill_cache_indices = prefill_cache_indices[0]
-
-        cu_seqlen_prefill = torch.tensor(
-            cu_seqlen_prefill, device=device, dtype=torch.int32
-        )
-
-        position_ids = position_ids.to(device)
-        slot_indices = slot_indices.to(device)
-        prefill_cache_indices = (
-            prefill_cache_indices.to(device) if sliding_window is not None else None
-        )
-        input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
-        input_lengths_tensor = torch.tensor(
-            input_lengths, dtype=torch.int32, device=device
-        )
-
-        if all_prefill_logprobs:
-            prefill_head_indices = None
-            prefill_next_token_indices = cu_seqlen_prefill[1:] - 1
-        elif no_prefill_logprobs:
-            prefill_head_indices = cu_seqlen_prefill[1:] - 1
-            prefill_next_token_indices = None
-        else:
-            prefill_head_indices = torch.tensor(
-                torch.cat(prefill_head_indices), dtype=torch.int64, device=device
-            )
-            prefill_next_token_indices = torch.tensor(
-                prefill_next_token_indices, dtype=torch.int64, device=device
-            )
-        top_n_tokens_tensor = torch.tensor(
-            top_n_tokens, device=device, dtype=torch.int64
-        )
-
-        return cls(
-            batch_id=pb.id,
-            requests=pb.requests,
-            requests_idx_mapping=requests_idx_mapping,
-            input_ids=input_ids,
-            position_ids=position_ids,
-            cu_seqlen_prefill=cu_seqlen_prefill,
-            start_slots=start_slots,
-            slot_indices=slot_indices,
-            needed_blocks_slots=needed_blocks_slots,
-            block_tables=None,
-            block_tables_tensor=None,
-            slots=None,
-            max_seqlen=max_seqlen,
-            prefill_head_indices=prefill_head_indices,
-            prefill_next_token_indices=prefill_next_token_indices,
-            prefill_cu_outlens=prefill_cu_outlens,
-            input_lengths=input_lengths,
-            input_lengths_tensor=input_lengths_tensor,
-            prefix_offsets=prefix_offsets,
-            read_offsets=read_offsets,
-            all_input_ids=all_input_ids,
-            all_input_ids_tensor=all_input_ids_tensor,
-            next_token_chooser=next_token_chooser,
-            stopping_criterias=stopping_criterias,
-            top_n_tokens=top_n_tokens,
-            top_n_tokens_tensor=top_n_tokens_tensor,
-            blocks=blocks,
-            max_blocks=max_blocks,
-            prefill_cache_indices=prefill_cache_indices,
-            speculative_ids=None,
-        )
+tracer = trace.get_tracer(__name__)
 
 
 class BaseFlashMistral(FlashCausalLM):
@@ -344,9 +60,7 @@ class BaseFlashMistral(FlashCausalLM):
 
         # Set context windows
         if getattr(config, "sliding_window", None) is not None:
-            set_sliding_window(
-                config.sliding_window, math.ceil(config.sliding_window / BLOCK_SIZE)
-            )
+            set_sliding_window(config.sliding_window)
         else:
             config.sliding_window = None
 
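The dropped second argument at these call sites was derived data. Assuming the shared set_sliding_window helper recomputes it internally (this diff does not show that code), the derivation is just:

import math

BLOCK_SIZE = 16  # assumed paged-attention block size

# Hypothetical equivalent of the value the old call sites computed by hand.
def window_blocks(sliding_window: int) -> int:
    return math.ceil(sliding_window / BLOCK_SIZE)

assert window_blocks(4096) == 256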
@@ -384,207 +98,6 @@ class BaseFlashMistral(FlashCausalLM):
             model.model.head_size,
         )
 
-    def max_past(self) -> int:
-        return self.model.max_past
-
-    @property
-    def batch_type(self) -> Type[FlashMistralBatch]:
-        return FlashMistralBatch
-
-    def tunableop_warmup(self, seqlen: int):
-        input_ids = torch.zeros(seqlen, dtype=torch.int64, device=self.device)
-        position_ids = torch.zeros(seqlen, dtype=torch.int32, device=self.device)
-        slots = torch.arange(seqlen, dtype=torch.int64, device=self.device)
-        kv_cache = get_cache_manager().kv_cache
-
-        # Dummy value, some models (starcoder2) don't accept `None`.
-        input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
-
-        # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
-        self.model.forward(
-            input_ids=input_ids,
-            position_ids=position_ids,
-            cu_seqlen_prefill=torch.tensor(
-                [0, seqlen], device=self.device, dtype=torch.int32
-            ),
-            kv_cache=get_cache_manager().kv_cache,
-            block_tables=None,
-            input_lengths=input_lengths,
-            slots=slots,
-            max_s=seqlen,
-            lm_head_indices=None,
-            prefill_cache_indices=None,
-        )
-
-    def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
-        input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
-        position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
-        slots = torch.arange(bs, dtype=torch.int64, device=self.device)
-        input_lengths = torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
-        block_tables = (
-            torch.arange(max_bt, dtype=torch.int32, device=self.device)
-            .repeat(bs)
-            .reshape((bs, max_bt))
-        )
-        kv_cache = get_cache_manager().kv_cache
-
-        self.cuda_graphs[bs] = {
-            "input_ids": input_ids,
-            "position_ids": position_ids,
-            "kv_cache": kv_cache,
-            "block_tables": block_tables,
-            "slots": slots,
-            "input_lengths": input_lengths,
-        }
-        graph = torch.cuda.CUDAGraph()
-        self.cuda_graphs[bs]["graph"] = graph
-
-        torch.cuda.synchronize()
-        # Run once outside to warmup
-        self.model.forward(
-            input_ids=input_ids,
-            position_ids=position_ids,
-            cu_seqlen_prefill=None,
-            kv_cache=kv_cache,
-            block_tables=block_tables,
-            slots=slots,
-            input_lengths=input_lengths,
-            max_s=max_s,
-            prefill_cache_indices=None,
-            lm_head_indices=None,
-        )
-        torch.cuda.synchronize()
-
-        with torch.cuda.graph(graph, pool=MEM_POOL):
-            logits, speculative_logits = self.model.forward(
-                input_ids=input_ids,
-                position_ids=position_ids,
-                cu_seqlen_prefill=None,
-                kv_cache=kv_cache,
-                block_tables=block_tables,
-                slots=slots,
-                input_lengths=input_lengths,
-                max_s=max_s,
-                prefill_cache_indices=None,
-                lm_head_indices=None,
-            )
-            self.cuda_graphs[bs]["logits"] = logits
-            self.cuda_graphs[bs]["speculative_logits"] = speculative_logits
-        torch.cuda.synchronize()
-
-    def forward(
-        self, batch: FlashMistralBatch
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
-        # Model Forward
-        if batch.speculative_ids is not None:
-            input_ids = batch.input_ids
-            position_ids = batch.position_ids
-            cu_seqlen_prefill = batch.cu_seqlen_prefill
-            kv_cache = get_cache_manager().kv_cache
-            block_tables = batch.block_tables_tensor
-            slots = batch.slots[batch.slot_indices]
-            input_lengths = batch.input_lengths_tensor
-            max_s = batch.max_seqlen
-            lm_head_indices = batch.prefill_head_indices
-
-            speculative_ids = batch.speculative_ids
-
-            B, speculative_length = speculative_ids.shape
-            new_length = speculative_length + 1
-            new_input_ids = torch.cat(
-                [input_ids.unsqueeze(-1), speculative_ids], dim=1
-            ).reshape(-1)
-            arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
-            arange_int = arange.to(dtype=torch.int32)
-            new_position_ids = (
-                position_ids.unsqueeze(-1).expand(B, new_length) + arange
-            ).view(-1)
-            slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
-            input_lengths = (
-                input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
-            ).view(-1)
-
-            # Add Copy the block tables for all members
-            block_tables = (
-                block_tables.unsqueeze(1)
-                .expand(B, new_length, -1)
-                .reshape(B * new_length, -1)
-                .contiguous()
-            )
-            max_s = max_s + speculative_length
-
-            input_ids = new_input_ids
-            position_ids = new_position_ids
-        else:
-            input_ids = batch.input_ids
-            position_ids = batch.position_ids
-            cu_seqlen_prefill = batch.cu_seqlen_prefill
-            kv_cache = get_cache_manager().kv_cache
-            block_tables = batch.block_tables_tensor
-            slots = batch.slots[batch.slot_indices]
-            input_lengths = batch.input_lengths_tensor
-            max_s = batch.max_seqlen
-            lm_head_indices = batch.prefill_head_indices
-
-        if cu_seqlen_prefill is None and self.max_past() is not None:
-            # In decode, not prefill, we're actually overwriting the KV-cache
-            # in a circular buffer mode.
-            # This makes sure the max_s for the decode pass is correct.
-            max_s = min(self.max_past(), max_s)
-
-        bs = input_ids.shape[0]
-        padded_bs = bs
-        if bs == 3:
-            padded_bs = 4
-        elif 3 < bs <= 8:
-            padded_bs = 8
-        elif bs > 8:
-            padded_bs = (bs + 7) // 8 * 8
-
-        # Try to find an associated cuda graph
-        cuda_graph = self.cuda_graphs.get(padded_bs, None)
-
-        if cu_seqlen_prefill is not None or cuda_graph is None:
-            logits, speculative_logits = self.model.forward(
-                input_ids=input_ids,
-                position_ids=position_ids,
-                cu_seqlen_prefill=cu_seqlen_prefill,
-                kv_cache=kv_cache,
-                block_tables=block_tables,
-                slots=slots,
-                input_lengths=input_lengths,
-                max_s=max_s,
-                prefill_cache_indices=batch.prefill_cache_indices,
-                lm_head_indices=lm_head_indices,
-            )
-            if batch.prefill_cache_indices is not None:
-                batch.prefill_cache_indices = None
-            return logits, speculative_logits
-
-        # Copy inputs to the static inputs of the cuda graph
-        # Static inputs are potentially padded
-        cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
-        cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
-        cuda_graph["block_tables"][
-            : block_tables.shape[0], : block_tables.shape[1]
-        ] = block_tables
-        cuda_graph["slots"].fill_(-1)
-        cuda_graph["slots"][: slots.shape[0]] = slots
-        cuda_graph["input_lengths"].zero_()
-        cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
-
-        # Replay the graph
-        cuda_graph["graph"].replay()
-
-        # Slice output to the correct shape
-        speculative_logits = (
-            cuda_graph["speculative_logits"][:bs]
-            if cuda_graph["speculative_logits"] is not None
-            else None
-        )
-        logits = cuda_graph["logits"][:bs]
-        return logits, speculative_logits
-
 
 class FlashMistral(BaseFlashMistral):
     def __init__(
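The speculative-decoding expansion in the forward removed above now lives only in the shared FlashCausalLM.forward. A self-contained sketch of that reshaping (plain CPU tensors and illustrative values, so the shapes are visible): each request verifies its last accepted token plus S draft tokens, so every per-token tensor is repeated S + 1 times with a running offset.

import torch

def expand_for_speculation(input_ids, position_ids, slots, input_lengths, speculative_ids):
    B, S = speculative_ids.shape
    new_length = S + 1
    arange = torch.arange(new_length).unsqueeze(0)
    # Interleave each accepted token with its draft continuation.
    new_input_ids = torch.cat([input_ids.unsqueeze(-1), speculative_ids], dim=1).reshape(-1)
    new_position_ids = (position_ids.unsqueeze(-1).expand(B, new_length) + arange).view(-1)
    new_slots = (slots.unsqueeze(-1).expand(B, new_length) + arange).view(-1)
    new_input_lengths = (input_lengths.unsqueeze(-1).expand(B, new_length) + arange).view(-1)
    return new_input_ids, new_position_ids, new_slots, new_input_lengths

ids, pos, slots, lens = expand_for_speculation(
    torch.tensor([11, 22]), torch.tensor([5, 9]),
    torch.tensor([80, 160]), torch.tensor([6, 10]),
    speculative_ids=torch.tensor([[101, 102], [201, 202]]),
)
print(ids.tolist())  # [11, 101, 102, 22, 201, 202]
print(pos.tolist())  # [5, 6, 7, 9, 10, 11]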
@@ -7,7 +7,6 @@ from opentelemetry import trace
 from transformers import AutoTokenizer, AutoConfig
 from typing import Optional
 
-from text_generation_server.models.cache_manager import BLOCK_SIZE
 from text_generation_server.models.flash_mistral import (
     BaseFlashMistral,
     set_sliding_window,
@@ -57,9 +56,7 @@ class FlashQwen2(BaseFlashMistral):
 
         # Set context windows
         if config.sliding_window is not None:
-            set_sliding_window(
-                config.sliding_window, math.ceil(config.sliding_window / BLOCK_SIZE)
-            )
+            set_sliding_window(config.sliding_window)
 
         torch.distributed.barrier(group=self.process_group)
 
@@ -6,7 +6,6 @@ from typing import Optional
 
 from transformers.models.gpt2 import GPT2TokenizerFast
 
-from text_generation_server.models.cache_manager import BLOCK_SIZE
 from text_generation_server.models.flash_mistral import (
     BaseFlashMistral,
     set_sliding_window,
@@ -56,9 +55,7 @@ class FlashStarcoder2(BaseFlashMistral):
 
         # Set context windows
         if config.sliding_window is not None:
-            set_sliding_window(
-                config.sliding_window, math.ceil(config.sliding_window / BLOCK_SIZE)
-            )
+            set_sliding_window(config.sliding_window)
 
         torch.distributed.barrier(group=self.process_group)
 
@@ -11,13 +11,9 @@ from typing import Optional, Tuple, List, Type, Dict
 from transformers import PreTrainedTokenizerBase
 from transformers.image_processing_utils import select_best_resolution
 from text_generation_server.pb import generate_pb2
+from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch
 from text_generation_server.models.flash_mistral import (
     BaseFlashMistral,
-    FlashMistralBatch,
-)
-from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch
-from text_generation_server.models.cache_manager import (
-    get_cache_manager,
 )
 
 tracer = trace.get_tracer(__name__)
@@ -140,7 +136,7 @@ def load_data_uri(image_uri: str) -> Image.Image:
     return image
 
 
-class VlmCausalLMBatch(FlashMistralBatch):
+class VlmCausalLMBatch(FlashCausalLMBatch):
     pixel_values: Optional[List[torch.Tensor]]
     pixel_attention_mask: Optional[List[torch.Tensor]]
     image_sizes: Optional[List[Tuple[int, int]]]
@@ -268,7 +264,7 @@ class VlmCausalLM(BaseFlashMistral):
             input_ids = batch.input_ids
             position_ids = batch.position_ids
             cu_seqlen_prefill = batch.cu_seqlen_prefill
-            kv_cache = get_cache_manager().kv_cache
+            kv_cache = self.kv_cache
             block_tables = batch.block_tables_tensor
             slots = batch.slots[batch.slot_indices]
             input_lengths = batch.input_lengths_tensor
@@ -307,7 +303,7 @@ class VlmCausalLM(BaseFlashMistral):
             input_ids = batch.input_ids
             position_ids = batch.position_ids
             cu_seqlen_prefill = batch.cu_seqlen_prefill
-            kv_cache = get_cache_manager().kv_cache
+            kv_cache = self.kv_cache
             block_tables = batch.block_tables_tensor
             slots = batch.slots[batch.slot_indices]
             input_lengths = batch.input_lengths_tensor