All the assertions.

Invariants added.

Remove the logs.
Nicolas Patry 2025-02-28 17:41:22 +01:00
parent 463228ebfc
commit ddf0b02240
No known key found for this signature in database
GPG Key ID: 4242CEF24CB6DBF9
12 changed files with 394 additions and 30 deletions

Cargo.lock (generated): 22 changes
View File

@ -1169,6 +1169,16 @@ dependencies = [
"cfg-if", "cfg-if",
] ]
[[package]]
name = "env_logger"
version = "0.8.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a19187fea3ac7e84da7dacf48de0c45d63c6a76f9490dae389aead16c243fce3"
dependencies = [
"log",
"regex",
]
[[package]]
name = "equivalent"
version = "1.0.1"
@ -3403,6 +3413,17 @@ version = "2.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a993555f31e5a609f617c12db6250dedcac1b0a85076912c436e6fc9b2c8e6a3"
[[package]]
name = "quickcheck"
version = "1.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "588f6378e4dd99458b60ec275b4477add41ce4fa9f64dcba6f15adccb19b50d6"
dependencies = [
"env_logger",
"log",
"rand",
]
[[package]]
name = "quote"
version = "1.0.37"
@ -4630,6 +4651,7 @@ dependencies = [
"opentelemetry-otlp", "opentelemetry-otlp",
"prost 0.12.6", "prost 0.12.6",
"prost-build", "prost-build",
"quickcheck",
"rand", "rand",
"regex", "regex",
"reqwest 0.11.27", "reqwest 0.11.27",

View File

@ -7,6 +7,7 @@ use grpc_metadata::InjectTelemetryContext;
use pb::generate::v3::text_generation_service_client::TextGenerationServiceClient;
use pb::generate::v3::*;
use std::cmp::min;
use std::io::Error;
use std::time::Duration;
use tonic::transport::{Channel, Uri};
use tracing::instrument;
@ -232,6 +233,20 @@ impl Client {
batch: Batch,
cached_batch: Option<CachedBatch>,
) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
let slots: Vec<_> = batch
.requests
.iter()
.map(|r| &r.slots[r.cache_len as usize..])
.flatten()
.collect();
assert_eq!(
slots.len(),
slots.iter().collect::<std::collections::HashSet<_>>().len()
);
if slots.len() != slots.iter().collect::<std::collections::HashSet<_>>().len() {
std::process::exit(1);
}
let request = tonic::Request::new(PrefillRequest {
batch: Some(batch),
cached_batch,
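
For reference, the invariant asserted above before every prefill call is that no KV-cache slot beyond a request's cached prefix is handed to more than one request in the batch. A minimal standalone sketch of that check, with Request and Batch as simplified stand-ins for the generated protobuf types (illustrative only, not the crate's API):

use std::collections::HashSet;

// Simplified stand-ins for the generated pb types (hypothetical, for illustration).
struct Request {
    cache_len: u32,
    slots: Vec<u32>,
}

struct Batch {
    requests: Vec<Request>,
}

// True when every slot past the cached prefix is used by exactly one request.
fn slots_are_unique(batch: &Batch) -> bool {
    let slots: Vec<u32> = batch
        .requests
        .iter()
        .flat_map(|r| r.slots[r.cache_len as usize..].iter().copied())
        .collect();
    let unique: HashSet<u32> = slots.iter().copied().collect();
    unique.len() == slots.len()
}

fn main() {
    let batch = Batch {
        requests: vec![
            Request { cache_len: 0, slots: vec![0, 1, 2] },
            // Slot 2 is part of this request's cached prefix, so it is not re-counted.
            Request { cache_len: 1, slots: vec![2, 3, 4] },
        ],
    };
    assert!(slots_are_unique(&batch));
}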

View File

@ -63,6 +63,7 @@ base64 = { workspace = true }
prost = "^0.12"
tonic = "^0.10"
tower = "^0.4"
quickcheck = "1.0.3"
[build-dependencies]
tonic-build = "0.10.1"

View File

@ -5,6 +5,7 @@ use crate::client::{
use crate::queue::{Entry, Queue};
use async_trait::async_trait;
use nohash_hasher::IntMap;
use std::collections::HashMap;
use std::sync::Arc;
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
use text_generation_router::validation::ValidGenerateRequest;
@ -13,6 +14,7 @@ use tokio::sync::mpsc::error::SendError;
use tokio::sync::{mpsc, Notify};
use tokio::time::Instant;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tonic::IntoRequest;
use tracing::{info_span, instrument, Instrument, Span};
pub struct BackendV3 {
@ -121,6 +123,35 @@ impl Backend for BackendV3 {
}
}
impl Batch {
pub fn check(&self) -> Result<(), InferError> {
let slots: Vec<_> = self
.requests
.iter()
.map(|r| &r.slots[r.cache_len as usize..])
.flatten()
.collect();
// assert_eq!(
// slots.len(),
// slots.iter().collect::<std::collections::HashSet<_>>().len()
// );
if slots.len() != slots.iter().collect::<std::collections::HashSet<_>>().len() {
let mut map: HashMap<u32, usize> = HashMap::new();
for slot in slots {
*map.entry(*slot).or_default() += 1usize;
}
let duplicates: HashMap<_, _> = map.into_iter().filter(|(_slot, c)| *c > 1).collect();
Err(InferError::GenerationError(format!(
"Invalid batch: {duplicates:?}",
)))
} else {
Ok(())
}
}
}
/// Batching logic
/// Will be launched in a background Tokio task
///
@ -154,6 +185,7 @@ pub(crate) async fn batching_task(
)
.await
{
batch.check().unwrap();
let mut cached_batch = prefill(&mut client, batch, None, &mut entries)
.instrument(span)
.await;
@ -205,6 +237,7 @@ pub(crate) async fn batching_task(
.next_batch(min_size, max_size, prefill_token_budget, token_budget)
.await
{
new_batch.check().unwrap();
// Tracking metrics
if min_size.is_some() {
metrics::counter!("tgi_batch_concat", "reason" => "backpressure")
@ -225,6 +258,7 @@ pub(crate) async fn batching_task(
// concatenated during the prefill op server side
entries.extend(new_entries);
// Generate one token for both the cached batch and the new batch
new_batch.check().unwrap();
let new_cached_batch =
prefill(&mut client, new_batch, cached_batch, &mut entries)
.instrument(span)
@ -249,6 +283,7 @@ pub(crate) async fn batching_task(
});
// Generate one token for this new batch to have the attention past in cache
new_batch.check().unwrap();
let new_cached_batch =
prefill(&mut client, new_batch, None, &mut new_entries)
.instrument(span)
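
The Batch::check helper above applies the same rule at the batching level, and its error message lists which slots collide and how often. A minimal sketch of that duplicate counting on a plain slice (the names here are illustrative, not the crate's API):

use std::collections::HashMap;

// Count occurrences per slot and keep only the offenders, which is what the
// "Invalid batch" error reports.
fn duplicate_slots(slots: &[u32]) -> HashMap<u32, usize> {
    let mut counts: HashMap<u32, usize> = HashMap::new();
    for &slot in slots {
        *counts.entry(slot).or_default() += 1;
    }
    counts.into_iter().filter(|(_, c)| *c > 1).collect()
}

fn main() {
    let dups = duplicate_slots(&[5, 7, 7, 9, 9, 9]);
    assert_eq!(dups.get(&7), Some(&2));
    assert_eq!(dups.get(&9), Some(&3));
    assert!(dups.get(&5).is_none());
}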

View File

@ -19,7 +19,13 @@ pub struct BlockAllocation {
impl Drop for BlockAllocation {
fn drop(&mut self) {
if let Some(block_allocator) = self.block_allocator.as_mut() {
tracing::debug!("Freeing block {}", self.allocation_id);
block_allocator.free(self.blocks.clone(), self.allocation_id)
} else {
#[cfg(not(test))]
{
panic!("We didn't have a block allocator");
}
}
}
}
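
The Drop change above makes a missing allocator a hard failure outside of tests, so blocks cannot silently leak. A minimal sketch of the RAII pattern involved, using a hypothetical trimmed-down allocator handle in place of the real channel-backed one:

// Hypothetical, simplified types for illustration; the real BlockAllocation
// lives in block_allocator.rs and reports back to an allocator task.
struct SimpleAllocator;

impl SimpleAllocator {
    fn free(&mut self, blocks: Vec<u32>, allocation_id: u64) {
        println!("freeing {} blocks for allocation {allocation_id}", blocks.len());
    }
}

struct OwnedAllocation {
    allocation_id: u64,
    blocks: Vec<u32>,
    block_allocator: Option<SimpleAllocator>,
}

impl Drop for OwnedAllocation {
    fn drop(&mut self) {
        if let Some(allocator) = self.block_allocator.as_mut() {
            allocator.free(self.blocks.clone(), self.allocation_id);
        } else {
            // Outside tests, a missing allocator is treated as a bug.
            #[cfg(not(test))]
            {
                panic!("dropped an allocation without an allocator");
            }
        }
    }
}

fn main() {
    let _alloc = OwnedAllocation {
        allocation_id: 1,
        blocks: vec![0, 1, 2],
        block_allocator: Some(SimpleAllocator),
    };
    // The blocks are freed automatically when _alloc goes out of scope.
}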

View File

@ -1,6 +1,7 @@
/// Single shard Client
use crate::client::{pb, Chunk};
use crate::client::{ClientError, Result, WARMUP_IMAGE_BASE64};
use axum::http::Error;
use base64::engine::general_purpose::STANDARD;
use base64::Engine;
use grpc_metadata::InjectTelemetryContext;
@ -232,11 +233,28 @@ impl Client {
batch: Batch,
cached_batch: Option<CachedBatch>,
) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
let slots: Vec<_> = batch
.requests
.iter()
.map(|r| &r.slots[r.cache_len as usize..])
.flatten()
.collect();
assert_eq!(
slots.len(),
slots.iter().collect::<std::collections::HashSet<_>>().len()
);
if slots.len() != slots.iter().collect::<std::collections::HashSet<_>>().len() {
std::process::exit(1);
}
let request = tonic::Request::new(PrefillRequest {
batch: Some(batch),
cached_batch,
})
.inject_context();
// if slots.len() != slots.iter().collect::<std::collections::HashSet<_>>().len() {
// return Err(Error::from("Test"));
// }
let response = self.stub.prefill(request).await?.into_inner();
Ok((
response.generations,

View File

@ -5,7 +5,7 @@ use crate::client::{
};
use nohash_hasher::{BuildNoHashHasher, IntMap};
use std::cmp::max;
use std::collections::{HashMap, VecDeque};
use text_generation_router::infer::InferError;
use text_generation_router::infer::InferStreamResponse;
use text_generation_router::validation::{
@ -269,6 +269,8 @@ impl State {
let mut decode_tokens: u32 = 0;
let mut max_blocks = 0;
let mut viewed: HashMap<u32, usize> = HashMap::new();
// Pop entries starting from the front of the queue
'entry_loop: while let Some((id, entry)) = self.entries.pop_front() {
// Filter entries where the response receiver was dropped (== entries where the request
@ -311,7 +313,7 @@ impl State {
+ entry.request.stopping_parameters.max_new_tokens
+ self.speculate
- 1;
// tracing::debug!("Allocating {tokens} with {input_ids:?}");
let block_allocation = match block_allocator.allocate(tokens, input_ids).await {
None => {
@ -322,10 +324,11 @@ impl State {
break 'entry_loop;
}
Some(mut block_allocation) => {
// tracing::debug!("Allocation: {block_allocation:?}");
max_blocks = max(max_blocks, block_allocation.blocks.len() as u32);
if block_allocation.prefix_len >= entry.request.input_length {
// panic!("Something wrong happened we have overmatched the prefix {} >= {}", block_allocation.prefix_len, entry.request.input_length);
// The whole request was found in the radix trie
// However, for the transformer forward to work, we need to
// have at least one token of postfix.
@ -336,6 +339,13 @@ impl State {
}
};
let new_slots = &block_allocation.slots[block_allocation.prefix_len as usize..];
for s in new_slots {
let entry = viewed.entry(*s).or_default();
*entry += 1;
assert!(*entry <= 1);
}
let postfix_len = entry.request.input_length - block_allocation.prefix_len;
if prefill_tokens + postfix_len > prefill_token_budget {
@ -349,6 +359,14 @@ impl State {
} else {
// We cannot prefill even one token for this entry
// Add it back to the queue
// Removing the allocations.
tracing::debug!("Removing some allocations");
for s in new_slots {
let entry = viewed.entry(*s).or_default();
*entry -= 1;
assert!(*entry <= 1);
}
self.entries.push_front((id, entry));
}
tracing::debug!(
@ -363,6 +381,12 @@ impl State {
"Over budget: prefill_tokens={} > {prefill_token_budget}", "Over budget: prefill_tokens={} > {prefill_token_budget}",
prefill_tokens + postfix_len prefill_tokens + postfix_len
); );
tracing::debug!("Removing some allocations");
for s in new_slots {
let entry = viewed.entry(*s).or_default();
*entry -= 1;
assert!(*entry <= 1);
}
self.entries.push_front((id, entry));
break 'entry_loop;
}

View File

@ -1,5 +1,6 @@
use crate::block_allocator::{Allocator, BlockAllocation};
use slotmap::{DefaultKey, SlotMap};
use std::collections::HashSet;
use std::hash::{Hash, Hasher};
use std::{
collections::{BTreeSet, HashMap},
@ -86,9 +87,11 @@ impl Allocator for RadixAllocator {
) -> Option<BlockAllocation> {
let mut blocks = vec![];
let prefix_node = if let Some(prefill_tokens) = prefill_tokens.as_ref() {
let node_id = self.cache_blocks.find(
// &prefill_tokens.as_slice()[..prefill_tokens.len().saturating_sub(1)],
&prefill_tokens.as_slice(),
&mut blocks,
);
node_id
} else {
self.cache_blocks.root_id()
@ -136,6 +139,26 @@ impl Allocator for RadixAllocator {
slots
};
tracing::debug!("Allocated {}", self.allocation_id);
let slot_set = slots.iter().collect::<HashSet<_>>();
let mut slot_count: HashMap<u32, usize> = HashMap::new();
for slot in &slots {
let entry = slot_count.entry(*slot).or_default();
*entry += 1;
}
let duplicates: HashMap<u32, usize> =
slot_count.into_iter().filter(|(_k, v)| *v > 1).collect();
// assert_eq!(slots.len(), slot_set.len(), "Duplicates {duplicates:?}");
let free_set = self.free_blocks.iter().collect::<HashSet<_>>();
assert_eq!(
free_set
.intersection(&slot_set)
.collect::<HashSet<_>>()
.len(),
0
);
let allocation = RadixAllocation {
prefix_node,
cached_prefix_len: prefix_len,
@ -144,6 +167,7 @@ impl Allocator for RadixAllocator {
self.allocation_id += 1;
self.allocations.insert(self.allocation_id, allocation);
tracing::debug!("Allocated {}", self.allocation_id);
Some(BlockAllocation {
allocation_id: self.allocation_id,
@ -155,7 +179,8 @@ impl Allocator for RadixAllocator {
}
fn free(&mut self, blocks: Vec<u32>, allocation_id: u64) {
tracing::debug!("Radix free {allocation_id}");
let allocation: RadixAllocation = match self.allocations.remove(&allocation_id) {
Some(allocation) => allocation,
None => unreachable!("Tried to free an unknown allocation."),
};
@ -283,7 +308,7 @@ impl RadixTrie {
}
/// Find worker.
fn find_(&mut self, node_id: NodeId, key: &[u32], blocks: &mut Vec<u32>) -> NodeId {
let node = &self.nodes[node_id];
if key.len() >= self.block_size {
@ -295,9 +320,13 @@ impl RadixTrie {
assert_eq!(shared_prefix_len % self.block_size, 0);
blocks.extend(&child.blocks[..shared_prefix_len / self.block_size]);
// A node represents the prefix of its children. So, only
// recurse when there is a full prefix match.
let key = &key[shared_prefix_len..];
if !key.is_empty() && shared_prefix_len == child.key.len() {
return self.find_(child_id, key, blocks);
} else {
return child_id;
}
}
}
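
The change to find_ above encodes one rule: a child node stands for a full prefix of everything below it, so the walk may only descend when the lookup key consumed the child's entire key; a partial match has to stop at that child instead of attributing deeper blocks. A sketch of that rule on plain slices, with the trie types left out:

// Returns how many leading elements of key match node_key, and whether it is
// safe to keep descending (the whole node key was consumed and key has more).
fn match_step(node_key: &[u32], key: &[u32]) -> (usize, bool) {
    let shared = node_key
        .iter()
        .zip(key.iter())
        .take_while(|(a, b)| a == b)
        .count();
    let descend = shared == node_key.len() && shared < key.len();
    (shared, descend)
}

fn main() {
    // Full match of the child key: keep walking down the trie.
    assert_eq!(match_step(&[1, 2], &[1, 2, 3]), (2, true));
    // Partial match: stop here and return this child as the prefix node.
    assert_eq!(match_step(&[1, 2, 9], &[1, 2, 3]), (2, false));
}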
@ -369,7 +398,6 @@ impl RadixTrie {
while let Some((last_access, node_id)) = self.leaves.pop_first() {
let blocks_needed = n_blocks.saturating_sub(evicted.len());
tracing::debug!("Evicting node {node_id:?} ");
let node = self.nodes.get(node_id).expect("Leave does not exist");
assert_eq!(
@ -381,11 +409,8 @@ impl RadixTrie {
if blocks_needed >= node.blocks.len() {
// We need to evict the whole node if we need more blocks than it has.
let node = self.remove_node(node_id);
tracing::debug!("Evicted node {node_id:?} got back {}", node.blocks.len());
evicted.extend(node.blocks);
if evicted.len() >= n_blocks {
break;
}
} else {
// The node has more blocks than needed, so we'll just remove
// the required number of blocks and leave the remaining blocks
@ -397,6 +422,10 @@ impl RadixTrie {
node.key.truncate(truncate_tokens);
evicted.extend(node.blocks.split_off(truncate_blocks));
self.leaves.insert((last_access, node_id));
tracing::debug!("Evicted partial node {node_id:?} got {blocks_needed} back",);
}
if evicted.len() >= n_blocks {
tracing::debug!("Got enough {}", evicted.len());
break;
}
}
@ -873,4 +902,76 @@ mod tests {
// Clear out the whole trie.
assert_eq!(trie.evict(10), vec![1, 2, 3, 0, 1]);
}
#[derive(Clone, Debug)]
enum Command {
Allocate {
tokens: u32,
prefill: Option<Arc<Vec<u32>>>,
},
Free {
blocks: Vec<u32>,
allocation_id: u64,
},
}
#[derive(Clone)]
struct Vocab(u32);
impl Arbitrary for Vocab {
fn arbitrary(gen: &mut Gen) -> Self {
let free = bool::arbitrary(gen);
if free {
Vocab(0)
} else {
Vocab(1)
}
}
}
use quickcheck::quickcheck;
use quickcheck::{Arbitrary, Gen};
impl Arbitrary for Command {
fn arbitrary(gen: &mut Gen) -> Self {
let free = bool::arbitrary(gen);
if free {
let blocks: Vec<u32> = Vec::arbitrary(gen);
let allocation_id = u64::arbitrary(gen);
Command::Free {
blocks,
allocation_id,
}
} else {
let tokens = u32::arbitrary(gen);
let prefill_tokens: Vec<Vocab> = Vec::arbitrary(gen);
let prefill_tokens = prefill_tokens.into_iter().map(|v| v.0).collect();
let prefill = Some(Arc::new(prefill_tokens));
Command::Allocate { tokens, prefill }
}
}
}
quickcheck! {
fn allocator_commands(commands: Vec<Command>) -> bool {
let mut cache = RadixAllocator::new(1, 20, None);
let mut allocations = vec![];
for command in commands{
match command{
Command::Allocate{tokens, prefill} => {
let allocation = cache.allocate(tokens, prefill);
if let Some(allocation) = allocation{
allocations.push(allocation.allocation_id);
}
}
Command::Free{blocks, allocation_id} => {
if allocations.contains(&allocation_id){
cache.free(blocks, allocation_id);
}
}
}
}
true
}
}
}
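
The property test above feeds the allocator random sequences of Allocate and Free commands through the Arbitrary impls and runs under a plain cargo test; quickcheck shrinks any failing sequence to a small reproducer. For orientation, this is the general shape of such a test, shown on a toy property rather than the allocator itself (it assumes quickcheck = "1" is available, as added to the Cargo files above):

#[cfg(test)]
mod property_sketch {
    use quickcheck::quickcheck;

    quickcheck! {
        // Sorting and deduplicating never yields a longer vector.
        fn dedup_never_grows(v: Vec<u32>) -> bool {
            let mut sorted = v.clone();
            sorted.sort_unstable();
            sorted.dedup();
            sorted.len() <= v.len()
        }
    }
}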

View File

@ -505,7 +505,7 @@ async fn generate_stream_internal(
let start_time = Instant::now();
metrics::counter!("tgi_request_count").increment(1);
// tracing::debug!("Input: {}", req.inputs);
let compute_characters = req.inputs.chars().count();

View File

@ -4,6 +4,7 @@ import os
import time
import torch
import torch.distributed
from collections import Counter
import numpy as np
@ -87,6 +88,71 @@ tracer = trace.get_tracer(__name__)
SLIDING_WINDOW: Optional[int] = None
WARMUP = True
def ASSERT_SIMPLE(batch):
slots = []
for r in batch.requests:
slots.extend(r.slots[r.cache_len :])
assert len(set(slots)) == len(slots)
def ASSERT_BATCH_IS_CORRECT(batch):
global WARMUP
input_ids = batch.input_ids
position_ids = batch.position_ids
cu_seqlen_prefill = batch.cu_seqlen_prefill
# kv_cache = self.kv_cache
block_tables = batch.block_tables_tensor
slots = batch.slots[batch.slot_indices]
input_lengths_tensor = batch.input_lengths_tensor
cache_lengths_tensor = batch.cache_lengths_tensor
max_s = batch.max_current_length
lm_head_indices = batch.prefill_head_indices
assert input_ids.shape == position_ids.shape
# print(input_lengths_tensor, cache_lengths_tensor, slots, block_tables)
assert input_lengths_tensor.shape == cache_lengths_tensor.shape
assert torch.all(cache_lengths_tensor >= 0)
assert torch.all(input_lengths_tensor > 0)
loffset = 0
coffset = 0
assert torch.unique(slots).shape == slots.shape, (
f"Slots {slots} - Cache {cache_lengths_tensor} Input {input_lengths_tensor} - Slto indices {batch.slot_indices} - Counter {Counter(slots.tolist()).most_common(3)} "
)
previous_slots = []
previous_blocks = []
for input_length, cache_length in zip(input_lengths_tensor, cache_lengths_tensor):
slot = slots[loffset : loffset + input_length]
blocks = block_tables[coffset][: input_length + cache_length]
assert len(slot.shape) == 1
# print(f"Blocks {blocks} - Slots {slots}")
assert torch.all(blocks[cache_length : cache_length + input_length] == slot)
if not WARMUP:
assert torch.all(blocks != 0)
assert torch.unique(blocks).shape == blocks.shape
for pblocks in previous_blocks:
m = min(pblocks.shape[0], blocks.shape[0])
diff = pblocks[:m] - blocks[:m]
NZ = diff.nonzero().view(-1)
if NZ.shape[0]:
# Remove the first offset
assert NZ[0] + NZ.shape[0] == m
NZ = NZ - NZ[0]
assert torch.all(NZ >= 0), f"{pblocks} - blocks {blocks} NZ {NZ}"
assert torch.all(NZ == torch.arange(NZ.shape[0], device=NZ.device))
loffset += input_length
coffset += 1
previous_slots.append(slot)
previous_blocks.append(blocks)
# assert cu_seqlen_prefill.shape == position_ids.shape
WARMUP = False
def small_power_of_2(n: int):
return 1 << ((n - 1).bit_length() - 1)
@ -135,6 +201,11 @@ def init_cpu_threads_env(rank_id: int, world_size: int):
)
from collections import defaultdict
HISTORY = defaultdict(list)
@dataclass
class FlashCausalLMBatch(Batch):
batch_id: int
@ -262,6 +333,7 @@ class FlashCausalLMBatch(Batch):
dtype: torch.dtype,
device: torch.device,
) -> "FlashCausalLMBatch":
HISTORY[pb.id].append(("TOKENIZED"))
speculate = get_speculate()
cache_lengths = []
@ -290,6 +362,7 @@ class FlashCausalLMBatch(Batch):
block_tables_ragged = []
# Parse batch
viewed = set()
for i, (r, tokenized_input) in enumerate(
zip(pb.requests, batch_tokenized_inputs)
):
@ -304,10 +377,16 @@ class FlashCausalLMBatch(Batch):
prompt_lengths.append(prompt_length)
cache_length = r.cache_len
new_slots = r.slots[cache_length:]
if set(new_slots).intersection(viewed):
import ipdb
ipdb.set_trace()
viewed.update(set(new_slots))
assert cache_length <= prompt_length, (
f"Prefix {cache_length} vs input {prompt_length}"
)
if cache_length == prompt_length:
assert False, "unreachable"
@ -325,9 +404,9 @@ class FlashCausalLMBatch(Batch):
postfix_ids = tokenized_input[
cache_length : cache_length + input_length
]
assert len(postfix_ids) == input_length, (
"Rust and Python tokenizers are not aligned"
)
else:
# Use all the remaining ids
postfix_ids = tokenized_input[cache_length:]
@ -378,6 +457,7 @@ class FlashCausalLMBatch(Batch):
cu_blocks.append(len(block_tables_ragged))
slots.extend(request_slots)
cu_slots.append(len(slots))
cache_lengths.append(cache_length)
@ -392,6 +472,25 @@ class FlashCausalLMBatch(Batch):
prompt_length + max_new_tokens + speculative_length,
)
# offset = 0
# new_slots = []
# total_slots = []
# for cache_length, input_length in zip(cache_lengths, input_lengths):
# new_slots_ = slots[
# offset + cache_length : offset + cache_length + input_length
# ]
# offset += cache_length + input_length
# new_slots.extend(new_slots_)
# total_slots.append(new_slots_)
# if new_slots:
# if Counter(new_slots).most_common(1)[0][1] != 1:
# import ipdb
# ipdb.set_trace()
# assert Counter(new_slots).most_common(1)[0][1] == 1, (
# f"New slots {new_slots}"
# )
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
next_token_chooser_parameters, dtype, device, tokenizer
)
@ -496,6 +595,7 @@ class FlashCausalLMBatch(Batch):
@tracer.start_as_current_span("filter")
def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
HISTORY[self.batch_id].append(("FILTER", request_ids))
if len(request_ids) == 0:
raise ValueError("Batch must have at least one request")
# We assume that if len(requests) == len(self) then the requests are the same
@ -702,6 +802,7 @@ class FlashCausalLMBatch(Batch):
@classmethod
@tracer.start_as_current_span("concatenate")
def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
HISTORY[batches[0].batch_id].append(("CONCATENATE", batches))
# Batch attributes
requests = []
requests_idx_mapping = {}
@ -884,6 +985,15 @@ class FlashCausalLMBatch(Batch):
cumulative_slots += len(batch.slots)
cumulative_batch_size += len(batch)
if slot_indices:
new_slots = slots[slot_indices]
import ipdb
ipdb.set_trace()
assert torch.unique(new_slots).shape == new_slots.shape, (
f"Slots {new_slots} - Cache {cache_lengths_tensor} Input {input_lengths_tensor} - Slto indices {slot_indices} - Counter {Counter(new_slots.tolist()).most_common(3)} "
)
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
next_token_chooser_parameters,
dtype=batches[0].next_token_chooser.dtype,
@ -991,6 +1101,16 @@ class FlashCausalLMBatch(Batch):
cu_slots_gpu,
self.position_ids,
self.slot_indices,
self.slots,
)
SLOTS = self.slots[self.slot_indices]
most_common = Counter(SLOTS.view(-1).tolist()).most_common(3)
if torch.unique(SLOTS.view(-1)).shape != SLOTS.view(-1).shape:
import ipdb
ipdb.set_trace()
assert torch.unique(SLOTS.view(-1)).shape == SLOTS.view(-1).shape, (
f"Slots {self.slots.view(-1)} Indices {self.slot_indices} - COUNTER {most_common} - Diff {self.slots == most_common[0][0]}"
)
sliding_window = get_sliding_windows()
@ -1813,7 +1933,7 @@ class FlashCausalLM(Model):
kv_cache = self.kv_cache
block_tables = batch.block_tables_tensor
slots = batch.slots[batch.slot_indices]
input_lengths_tensor = batch.input_lengths_tensor
cache_lengths_tensor = batch.cache_lengths_tensor
max_s = batch.max_current_length
lm_head_indices = batch.prefill_head_indices
@ -1832,6 +1952,8 @@ class FlashCausalLM(Model):
else:
cuda_graph = None
ASSERT_BATCH_IS_CORRECT(batch)
if cu_seqlen_prefill is not None or cuda_graph is None:
if ATTENTION == "flashinfer":
block_tables = block_tables_to_ragged(
@ -1845,11 +1967,11 @@ class FlashCausalLM(Model):
with self._forward_context(
block_tables=block_tables,
cu_seqlen_prefill=cu_seqlen_prefill,
input_lengths_tensor=input_lengths_tensor,
cache_lengths_tensor=cache_lengths_tensor,
):
seqlen = Seqlen(
input_lengths=input_lengths_tensor,
cache_lengths=cache_lengths_tensor,
cu_seqlen_q=cu_seqlen_prefill,
max_q=batch.max_input_length,
@ -1897,11 +2019,13 @@ class FlashCausalLM(Model):
cuda_graph["slots"].fill_(0)
cuda_graph["slots"][: slots.shape[0]] = slots
cuda_graph["input_lengths"].zero_()
cuda_graph["input_lengths"][: input_lengths_tensor.shape[0]] = (
input_lengths_tensor
)
cuda_graph["cache_lengths"].zero_()
cuda_graph["cache_lengths"][: cache_lengths_tensor.shape[0]] = (
cache_lengths_tensor
)
with self._forward_context(
block_tables=cuda_graph["block_tables"],

View File

@ -2,6 +2,7 @@ import torch
import triton
import triton.language as tl
from collections import Counter
from loguru import logger
from typing import List, Optional
@ -130,6 +131,7 @@ def prepare_position_slot_ids(
cu_slots: torch.Tensor,
position_ids: torch.Tensor,
slot_indices: torch.Tensor,
slots: torch.Tensor,
):
def grid(meta):
return (
@ -140,6 +142,12 @@ def prepare_position_slot_ids(
triton_prepare_position_slot_ids[grid](
cache_lengths, cu_seqlen, cu_slots, position_ids, slot_indices, BLOCK_SIZE=256
)
SLOTS = slots[slot_indices]
most_common = Counter(SLOTS.view(-1).tolist()).most_common(3)
if torch.unique(SLOTS.view(-1)).shape != SLOTS.view(-1).shape:
import ipdb
ipdb.set_trace()
def slots_filtering(
@ -158,6 +166,10 @@ def slots_filtering(
triton_slots_filtering[grid](
slots, filtered_slots, slots_start, cu_slots, BLOCK_SIZE=256
)
assert torch.all(slots[slots_start] == filtered_slots[cu_slots[:-1]])
# assert torch.unique(slots).shape == slots.shape, (
# f"Slots {slots} {Counter(slots.tolist()).most_common(3)}"
# )
@triton.jit

View File

@ -11,6 +11,10 @@ from grpc_reflection.v1alpha import reflection
from pathlib import Path
from typing import List, Optional
from text_generation_server.models.flash_causal_lm import (
ASSERT_BATCH_IS_CORRECT,
ASSERT_SIMPLE,
)
from text_generation_server.cache import Cache
from text_generation_server.interceptor import ExceptionInterceptor
from text_generation_server.models import Model, get_model_with_lora_adapters
@ -164,6 +168,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
self.model.device,
)
else:
ASSERT_SIMPLE(request.batch)
batch = self.model.batch_type.from_pb(
request.batch, self.model.tokenizer, self.model.dtype, self.model.device
)
@ -178,6 +183,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
)
start_concat = time.time_ns()
batch = self.model.batch_type.concatenate([cached_batch, batch])
# ASSERT_BATCH_IS_CORRECT(batch)
concat_ns = time.time_ns() - start_concat
generations, next_batch, timings = self.model.generate_token(batch)