All the assertions.

Invariants added

Remove the logs.
This commit is contained in:
Nicolas Patry 2025-02-28 17:41:22 +01:00
parent 463228ebfc
commit ddf0b02240
No known key found for this signature in database
GPG Key ID: 4242CEF24CB6DBF9
12 changed files with 394 additions and 30 deletions

22
Cargo.lock generated
View File

@ -1169,6 +1169,16 @@ dependencies = [
"cfg-if",
]
[[package]]
name = "env_logger"
version = "0.8.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a19187fea3ac7e84da7dacf48de0c45d63c6a76f9490dae389aead16c243fce3"
dependencies = [
"log",
"regex",
]
[[package]]
name = "equivalent"
version = "1.0.1"
@ -3403,6 +3413,17 @@ version = "2.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a993555f31e5a609f617c12db6250dedcac1b0a85076912c436e6fc9b2c8e6a3"
[[package]]
name = "quickcheck"
version = "1.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "588f6378e4dd99458b60ec275b4477add41ce4fa9f64dcba6f15adccb19b50d6"
dependencies = [
"env_logger",
"log",
"rand",
]
[[package]]
name = "quote"
version = "1.0.37"
@ -4630,6 +4651,7 @@ dependencies = [
"opentelemetry-otlp",
"prost 0.12.6",
"prost-build",
"quickcheck",
"rand",
"regex",
"reqwest 0.11.27",

View File

@ -7,6 +7,7 @@ use grpc_metadata::InjectTelemetryContext;
use pb::generate::v3::text_generation_service_client::TextGenerationServiceClient;
use pb::generate::v3::*;
use std::cmp::min;
use std::io::Error;
use std::time::Duration;
use tonic::transport::{Channel, Uri};
use tracing::instrument;
@ -232,6 +233,20 @@ impl Client {
batch: Batch,
cached_batch: Option<CachedBatch>,
) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
let slots: Vec<_> = batch
.requests
.iter()
.map(|r| &r.slots[r.cache_len as usize..])
.flatten()
.collect();
assert_eq!(
slots.len(),
slots.iter().collect::<std::collections::HashSet<_>>().len()
);
if slots.len() != slots.iter().collect::<std::collections::HashSet<_>>().len() {
std::process::exit(1);
}
let request = tonic::Request::new(PrefillRequest {
batch: Some(batch),
cached_batch,

View File

@ -63,6 +63,7 @@ base64 = { workspace = true }
prost = "^0.12"
tonic = "^0.10"
tower = "^0.4"
quickcheck = "1.0.3"
[build-dependencies]
tonic-build = "0.10.1"

View File

@ -5,6 +5,7 @@ use crate::client::{
use crate::queue::{Entry, Queue};
use async_trait::async_trait;
use nohash_hasher::IntMap;
use std::collections::HashMap;
use std::sync::Arc;
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
use text_generation_router::validation::ValidGenerateRequest;
@ -13,6 +14,7 @@ use tokio::sync::mpsc::error::SendError;
use tokio::sync::{mpsc, Notify};
use tokio::time::Instant;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tonic::IntoRequest;
use tracing::{info_span, instrument, Instrument, Span};
pub struct BackendV3 {
@ -121,6 +123,35 @@ impl Backend for BackendV3 {
}
}
impl Batch {
pub fn check(&self) -> Result<(), InferError> {
let slots: Vec<_> = self
.requests
.iter()
.map(|r| &r.slots[r.cache_len as usize..])
.flatten()
.collect();
// assert_eq!(
// slots.len(),
// slots.iter().collect::<std::collections::HashSet<_>>().len()
// );
if slots.len() != slots.iter().collect::<std::collections::HashSet<_>>().len() {
let mut map: HashMap<u32, usize> = HashMap::new();
for slot in slots {
*map.entry(*slot).or_default() += 1usize;
}
let duplicates: HashMap<_, _> = map.into_iter().filter(|(_slot, c)| *c > 1).collect();
Err(InferError::GenerationError(format!(
"Invalid batch: {duplicates:?}",
)))
} else {
Ok(())
}
}
}
/// Batching logic
/// Will be launched in a background Tokio task
///
@ -154,6 +185,7 @@ pub(crate) async fn batching_task(
)
.await
{
batch.check().unwrap();
let mut cached_batch = prefill(&mut client, batch, None, &mut entries)
.instrument(span)
.await;
@ -205,6 +237,7 @@ pub(crate) async fn batching_task(
.next_batch(min_size, max_size, prefill_token_budget, token_budget)
.await
{
new_batch.check().unwrap();
// Tracking metrics
if min_size.is_some() {
metrics::counter!("tgi_batch_concat", "reason" => "backpressure")
@ -225,6 +258,7 @@ pub(crate) async fn batching_task(
// concatenated during the prefill op server side
entries.extend(new_entries);
// Generate one token for both the cached batch and the new batch
new_batch.check().unwrap();
let new_cached_batch =
prefill(&mut client, new_batch, cached_batch, &mut entries)
.instrument(span)
@ -249,6 +283,7 @@ pub(crate) async fn batching_task(
});
// Generate one token for this new batch to have the attention past in cache
new_batch.check().unwrap();
let new_cached_batch =
prefill(&mut client, new_batch, None, &mut new_entries)
.instrument(span)

View File

@ -19,7 +19,13 @@ pub struct BlockAllocation {
impl Drop for BlockAllocation {
fn drop(&mut self) {
if let Some(block_allocator) = self.block_allocator.as_mut() {
tracing::debug!("Freeing block {}", self.allocation_id);
block_allocator.free(self.blocks.clone(), self.allocation_id)
} else {
#[cfg(not(test))]
{
panic!("We didn't have a block allocator");
}
}
}
}

View File

@ -1,6 +1,7 @@
/// Single shard Client
use crate::client::{pb, Chunk};
use crate::client::{ClientError, Result, WARMUP_IMAGE_BASE64};
use axum::http::Error;
use base64::engine::general_purpose::STANDARD;
use base64::Engine;
use grpc_metadata::InjectTelemetryContext;
@ -232,11 +233,28 @@ impl Client {
batch: Batch,
cached_batch: Option<CachedBatch>,
) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
let slots: Vec<_> = batch
.requests
.iter()
.map(|r| &r.slots[r.cache_len as usize..])
.flatten()
.collect();
assert_eq!(
slots.len(),
slots.iter().collect::<std::collections::HashSet<_>>().len()
);
if slots.len() != slots.iter().collect::<std::collections::HashSet<_>>().len() {
std::process::exit(1);
}
let request = tonic::Request::new(PrefillRequest {
batch: Some(batch),
cached_batch,
})
.inject_context();
// if slots.len() != slots.iter().collect::<std::collections::HashSet<_>>().len() {
// return Err(Error::from("Test"));
// }
let response = self.stub.prefill(request).await?.into_inner();
Ok((
response.generations,

View File

@ -5,7 +5,7 @@ use crate::client::{
};
use nohash_hasher::{BuildNoHashHasher, IntMap};
use std::cmp::max;
use std::collections::VecDeque;
use std::collections::{HashMap, VecDeque};
use text_generation_router::infer::InferError;
use text_generation_router::infer::InferStreamResponse;
use text_generation_router::validation::{
@ -269,6 +269,8 @@ impl State {
let mut decode_tokens: u32 = 0;
let mut max_blocks = 0;
let mut viewed: HashMap<u32, usize> = HashMap::new();
// Pop entries starting from the front of the queue
'entry_loop: while let Some((id, entry)) = self.entries.pop_front() {
// Filter entries where the response receiver was dropped (== entries where the request
@ -311,7 +313,7 @@ impl State {
+ entry.request.stopping_parameters.max_new_tokens
+ self.speculate
- 1;
tracing::debug!("Allocating {tokens} with {input_ids:?}");
// tracing::debug!("Allocating {tokens} with {input_ids:?}");
let block_allocation = match block_allocator.allocate(tokens, input_ids).await {
None => {
@ -322,10 +324,11 @@ impl State {
break 'entry_loop;
}
Some(mut block_allocation) => {
tracing::debug!("Allocation: {block_allocation:?}");
// tracing::debug!("Allocation: {block_allocation:?}");
max_blocks = max(max_blocks, block_allocation.blocks.len() as u32);
if block_allocation.prefix_len == entry.request.input_length {
if block_allocation.prefix_len >= entry.request.input_length {
// panic!("Something wrong happened we have overmatched the prefix {} >= {}", block_allocation.prefix_len, entry.request.input_length);
// The whole request was found in the radix trie
// However, for the transformer forward to work, we need to
// have at least one token of postfix.
@ -336,6 +339,13 @@ impl State {
}
};
let new_slots = &block_allocation.slots[block_allocation.prefix_len as usize..];
for s in new_slots {
let entry = viewed.entry(*s).or_default();
*entry += 1;
assert!(*entry <= 1);
}
let postfix_len = entry.request.input_length - block_allocation.prefix_len;
if prefill_tokens + postfix_len > prefill_token_budget {
@ -349,6 +359,14 @@ impl State {
} else {
// We cannot prefill even one token for this entry
// Add it back to the queue
// Removing the allocations.
tracing::debug!("Removing some allocations");
for s in new_slots {
let entry = viewed.entry(*s).or_default();
*entry -= 1;
assert!(*entry <= 1);
}
self.entries.push_front((id, entry));
}
tracing::debug!(
@ -363,6 +381,12 @@ impl State {
"Over budget: prefill_tokens={} > {prefill_token_budget}",
prefill_tokens + postfix_len
);
tracing::debug!("Removing some allocations");
for s in new_slots {
let entry = viewed.entry(*s).or_default();
*entry -= 1;
assert!(*entry <= 1);
}
self.entries.push_front((id, entry));
break 'entry_loop;
}

View File

@ -1,5 +1,6 @@
use crate::block_allocator::{Allocator, BlockAllocation};
use slotmap::{DefaultKey, SlotMap};
use std::collections::HashSet;
use std::hash::{Hash, Hasher};
use std::{
collections::{BTreeSet, HashMap},
@ -86,9 +87,11 @@ impl Allocator for RadixAllocator {
) -> Option<BlockAllocation> {
let mut blocks = vec![];
let prefix_node = if let Some(prefill_tokens) = prefill_tokens.as_ref() {
let node_id = self
.cache_blocks
.find(prefill_tokens.as_slice(), &mut blocks);
let node_id = self.cache_blocks.find(
// &prefill_tokens.as_slice()[..prefill_tokens.len().saturating_sub(1)],
&prefill_tokens.as_slice(),
&mut blocks,
);
node_id
} else {
self.cache_blocks.root_id()
@ -136,6 +139,26 @@ impl Allocator for RadixAllocator {
slots
};
tracing::debug!("Allocated {}", self.allocation_id);
let slot_set = slots.iter().collect::<HashSet<_>>();
let mut slot_count: HashMap<u32, usize> = HashMap::new();
for slot in &slots {
let entry = slot_count.entry(*slot).or_default();
*entry += 1;
}
let duplicates: HashMap<u32, usize> =
slot_count.into_iter().filter(|(_k, v)| *v > 1).collect();
// assert_eq!(slots.len(), slot_set.len(), "Duplicates {duplicates:?}");
let free_set = self.free_blocks.iter().collect::<HashSet<_>>();
assert_eq!(
free_set
.intersection(&slot_set)
.collect::<HashSet<_>>()
.len(),
0
);
let allocation = RadixAllocation {
prefix_node,
cached_prefix_len: prefix_len,
@ -144,6 +167,7 @@ impl Allocator for RadixAllocator {
self.allocation_id += 1;
self.allocations.insert(self.allocation_id, allocation);
tracing::debug!("Allocated {}", self.allocation_id);
Some(BlockAllocation {
allocation_id: self.allocation_id,
@ -155,7 +179,8 @@ impl Allocator for RadixAllocator {
}
fn free(&mut self, blocks: Vec<u32>, allocation_id: u64) {
let allocation = match self.allocations.remove(&allocation_id) {
tracing::debug!("Radix free {allocation_id}");
let allocation: RadixAllocation = match self.allocations.remove(&allocation_id) {
Some(allocation) => allocation,
None => unreachable!("Tried to free an unknown allocation."),
};
@ -283,7 +308,7 @@ impl RadixTrie {
}
/// Find worker.
fn find_(&mut self, mut node_id: NodeId, key: &[u32], blocks: &mut Vec<u32>) -> NodeId {
fn find_(&mut self, node_id: NodeId, key: &[u32], blocks: &mut Vec<u32>) -> NodeId {
let node = &self.nodes[node_id];
if key.len() >= self.block_size {
@ -295,9 +320,13 @@ impl RadixTrie {
assert_eq!(shared_prefix_len % self.block_size, 0);
blocks.extend(&child.blocks[..shared_prefix_len / self.block_size]);
// A node represents the prefix of its children. So, only
// recurse when there is a full prefix match.
let key = &key[shared_prefix_len..];
if !key.is_empty() {
node_id = self.find_(child_id, key, blocks);
if !key.is_empty() && shared_prefix_len == child.key.len() {
return self.find_(child_id, key, blocks);
} else {
return child_id;
}
}
}
@ -369,7 +398,6 @@ impl RadixTrie {
while let Some((last_access, node_id)) = self.leaves.pop_first() {
let blocks_needed = n_blocks.saturating_sub(evicted.len());
tracing::debug!("Evicting node {node_id:?} ");
let node = self.nodes.get(node_id).expect("Leave does not exist");
assert_eq!(
@ -381,11 +409,8 @@ impl RadixTrie {
if blocks_needed >= node.blocks.len() {
// We need to evict the whole node if we need more blocks than it has.
let node = self.remove_node(node_id);
tracing::debug!("Evicted node {node_id:?} got back {}", node.blocks.len());
evicted.extend(node.blocks);
if evicted.len() >= n_blocks {
break;
}
} else {
// The node has more blocks than needed, so we'll just remove
// the required number of blocks and leave the remaining blocks
@ -397,6 +422,10 @@ impl RadixTrie {
node.key.truncate(truncate_tokens);
evicted.extend(node.blocks.split_off(truncate_blocks));
self.leaves.insert((last_access, node_id));
tracing::debug!("Evicted partial node {node_id:?} got {blocks_needed} back",);
}
if evicted.len() >= n_blocks {
tracing::debug!("Got enough {}", evicted.len());
break;
}
}
@ -873,4 +902,76 @@ mod tests {
// Clear out the whole trie.
assert_eq!(trie.evict(10), vec![1, 2, 3, 0, 1]);
}
#[derive(Clone, Debug)]
enum Command {
Allocate {
tokens: u32,
prefill: Option<Arc<Vec<u32>>>,
},
Free {
blocks: Vec<u32>,
allocation_id: u64,
},
}
#[derive(Clone)]
struct Vocab(u32);
impl Arbitrary for Vocab {
fn arbitrary(gen: &mut Gen) -> Self {
let free = bool::arbitrary(gen);
if free {
Vocab(0)
} else {
Vocab(1)
}
}
}
use quickcheck::quickcheck;
use quickcheck::{Arbitrary, Gen};
impl Arbitrary for Command {
fn arbitrary(gen: &mut Gen) -> Self {
let free = bool::arbitrary(gen);
if free {
let blocks: Vec<u32> = Vec::arbitrary(gen);
let allocation_id = u64::arbitrary(gen);
Command::Free {
blocks,
allocation_id,
}
} else {
let tokens = u32::arbitrary(gen);
let prefill_tokens: Vec<Vocab> = Vec::arbitrary(gen);
let prefill_tokens = prefill_tokens.into_iter().map(|v| v.0).collect();
let prefill = Some(Arc::new(prefill_tokens));
Command::Allocate { tokens, prefill }
}
}
}
quickcheck! {
fn allocator_commands(commands: Vec<Command>) -> bool {
let mut cache = RadixAllocator::new(1, 20, None);
let mut allocations = vec![];
for command in commands{
match command{
Command::Allocate{tokens, prefill} => {
let allocation = cache.allocate(tokens, prefill);
if let Some(allocation) = allocation{
allocations.push(allocation.allocation_id);
}
}
Command::Free{blocks, allocation_id} => {
if allocations.contains(&allocation_id){
cache.free(blocks, allocation_id);
}
}
}
}
true
}
}
}

View File

@ -505,7 +505,7 @@ async fn generate_stream_internal(
let start_time = Instant::now();
metrics::counter!("tgi_request_count").increment(1);
tracing::debug!("Input: {}", req.inputs);
// tracing::debug!("Input: {}", req.inputs);
let compute_characters = req.inputs.chars().count();

View File

@ -4,6 +4,7 @@ import os
import time
import torch
import torch.distributed
from collections import Counter
import numpy as np
@ -87,6 +88,71 @@ tracer = trace.get_tracer(__name__)
SLIDING_WINDOW: Optional[int] = None
WARMUP = True
def ASSERT_SIMPLE(batch):
slots = []
for r in batch.requests:
slots.extend(r.slots[r.cache_len :])
assert len(set(slots)) == len(slots)
def ASSERT_BATCH_IS_CORRECT(batch):
global WARMUP
input_ids = batch.input_ids
position_ids = batch.position_ids
cu_seqlen_prefill = batch.cu_seqlen_prefill
# kv_cache = self.kv_cache
block_tables = batch.block_tables_tensor
slots = batch.slots[batch.slot_indices]
input_lengths_tensor = batch.input_lengths_tensor
cache_lengths_tensor = batch.cache_lengths_tensor
max_s = batch.max_current_length
lm_head_indices = batch.prefill_head_indices
assert input_ids.shape == position_ids.shape
# print(input_lengths_tensor, cache_lengths_tensor, slots, block_tables)
assert input_lengths_tensor.shape == cache_lengths_tensor.shape
assert torch.all(cache_lengths_tensor >= 0)
assert torch.all(input_lengths_tensor > 0)
loffset = 0
coffset = 0
assert torch.unique(slots).shape == slots.shape, (
f"Slots {slots} - Cache {cache_lengths_tensor} Input {input_lengths_tensor} - Slto indices {batch.slot_indices} - Counter {Counter(slots.tolist()).most_common(3)} "
)
previous_slots = []
previous_blocks = []
for input_length, cache_length in zip(input_lengths_tensor, cache_lengths_tensor):
slot = slots[loffset : loffset + input_length]
blocks = block_tables[coffset][: input_length + cache_length]
assert len(slot.shape) == 1
# print(f"Blocks {blocks} - Slots {slots}")
assert torch.all(blocks[cache_length : cache_length + input_length] == slot)
if not WARMUP:
assert torch.all(blocks != 0)
assert torch.unique(blocks).shape == blocks.shape
for pblocks in previous_blocks:
m = min(pblocks.shape[0], blocks.shape[0])
diff = pblocks[:m] - blocks[:m]
NZ = diff.nonzero().view(-1)
if NZ.shape[0]:
# Remove the first offset
assert NZ[0] + NZ.shape[0] == m
NZ = NZ - NZ[0]
assert torch.all(NZ >= 0), f"{pblocks} - blocks {blocks} NZ {NZ}"
assert torch.all(NZ == torch.arange(NZ.shape[0], device=NZ.device))
loffset += input_length
coffset += 1
previous_slots.append(slot)
previous_blocks.append(blocks)
# assert cu_seqlen_prefill.shape == position_ids.shape
WARMUP = False
def small_power_of_2(n: int):
return 1 << ((n - 1).bit_length() - 1)
@ -135,6 +201,11 @@ def init_cpu_threads_env(rank_id: int, world_size: int):
)
from collections import defaultdict
HISTORY = defaultdict(list)
@dataclass
class FlashCausalLMBatch(Batch):
batch_id: int
@ -262,6 +333,7 @@ class FlashCausalLMBatch(Batch):
dtype: torch.dtype,
device: torch.device,
) -> "FlashCausalLMBatch":
HISTORY[pb.id].append(("TOKENIZED"))
speculate = get_speculate()
cache_lengths = []
@ -290,6 +362,7 @@ class FlashCausalLMBatch(Batch):
block_tables_ragged = []
# Parse batch
viewed = set()
for i, (r, tokenized_input) in enumerate(
zip(pb.requests, batch_tokenized_inputs)
):
@ -304,10 +377,16 @@ class FlashCausalLMBatch(Batch):
prompt_lengths.append(prompt_length)
cache_length = r.cache_len
new_slots = r.slots[cache_length:]
if set(new_slots).intersection(viewed):
import ipdb
assert (
cache_length <= prompt_length
), f"Prefix {cache_length} vs input {prompt_length}"
ipdb.set_trace()
viewed.update(set(new_slots))
assert cache_length <= prompt_length, (
f"Prefix {cache_length} vs input {prompt_length}"
)
if cache_length == prompt_length:
assert False, "unreachable"
@ -325,9 +404,9 @@ class FlashCausalLMBatch(Batch):
postfix_ids = tokenized_input[
cache_length : cache_length + input_length
]
assert (
len(postfix_ids) == input_length
), "Rust and Python tokenizers are not aligned"
assert len(postfix_ids) == input_length, (
"Rust and Python tokenizers are not aligned"
)
else:
# Use all the remaining ids
postfix_ids = tokenized_input[cache_length:]
@ -378,6 +457,7 @@ class FlashCausalLMBatch(Batch):
cu_blocks.append(len(block_tables_ragged))
slots.extend(request_slots)
cu_slots.append(len(slots))
cache_lengths.append(cache_length)
@ -392,6 +472,25 @@ class FlashCausalLMBatch(Batch):
prompt_length + max_new_tokens + speculative_length,
)
# offset = 0
# new_slots = []
# total_slots = []
# for cache_length, input_length in zip(cache_lengths, input_lengths):
# new_slots_ = slots[
# offset + cache_length : offset + cache_length + input_length
# ]
# offset += cache_length + input_length
# new_slots.extend(new_slots_)
# total_slots.append(new_slots_)
# if new_slots:
# if Counter(new_slots).most_common(1)[0][1] != 1:
# import ipdb
# ipdb.set_trace()
# assert Counter(new_slots).most_common(1)[0][1] == 1, (
# f"New slots {new_slots}"
# )
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
next_token_chooser_parameters, dtype, device, tokenizer
)
@ -496,6 +595,7 @@ class FlashCausalLMBatch(Batch):
@tracer.start_as_current_span("filter")
def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
HISTORY[self.batch_id].append(("FILTER", request_ids))
if len(request_ids) == 0:
raise ValueError("Batch must have at least one request")
# We assume that if len(requests) == len(self) then the requests are the same
@ -702,6 +802,7 @@ class FlashCausalLMBatch(Batch):
@classmethod
@tracer.start_as_current_span("concatenate")
def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
HISTORY[batches[0].batch_id].append(("CONCATENATE", batches))
# Batch attributes
requests = []
requests_idx_mapping = {}
@ -884,6 +985,15 @@ class FlashCausalLMBatch(Batch):
cumulative_slots += len(batch.slots)
cumulative_batch_size += len(batch)
if slot_indices:
new_slots = slots[slot_indices]
import ipdb
ipdb.set_trace()
assert torch.unique(new_slots).shape == new_slots.shape, (
f"Slots {new_slots} - Cache {cache_lengths_tensor} Input {input_lengths_tensor} - Slto indices {slot_indices} - Counter {Counter(new_slots.tolist()).most_common(3)} "
)
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
next_token_chooser_parameters,
dtype=batches[0].next_token_chooser.dtype,
@ -991,7 +1101,17 @@ class FlashCausalLMBatch(Batch):
cu_slots_gpu,
self.position_ids,
self.slot_indices,
self.slots,
)
SLOTS = self.slots[self.slot_indices]
most_common = Counter(SLOTS.view(-1).tolist()).most_common(3)
if torch.unique(SLOTS.view(-1)).shape != SLOTS.view(-1).shape:
import ipdb
ipdb.set_trace()
assert torch.unique(SLOTS.view(-1)).shape == SLOTS.view(-1).shape, (
f"Slots {self.slots.view(-1)} Indices {self.slot_indices} - COUNTER {most_common} - Diff {self.slots == most_common[0][0]}"
)
sliding_window = get_sliding_windows()
position_ids = []
@ -1813,7 +1933,7 @@ class FlashCausalLM(Model):
kv_cache = self.kv_cache
block_tables = batch.block_tables_tensor
slots = batch.slots[batch.slot_indices]
input_lengths = batch.input_lengths_tensor
input_lengths_tensor = batch.input_lengths_tensor
cache_lengths_tensor = batch.cache_lengths_tensor
max_s = batch.max_current_length
lm_head_indices = batch.prefill_head_indices
@ -1832,6 +1952,8 @@ class FlashCausalLM(Model):
else:
cuda_graph = None
ASSERT_BATCH_IS_CORRECT(batch)
if cu_seqlen_prefill is not None or cuda_graph is None:
if ATTENTION == "flashinfer":
block_tables = block_tables_to_ragged(
@ -1845,11 +1967,11 @@ class FlashCausalLM(Model):
with self._forward_context(
block_tables=block_tables,
cu_seqlen_prefill=cu_seqlen_prefill,
input_lengths_tensor=input_lengths,
input_lengths_tensor=input_lengths_tensor,
cache_lengths_tensor=cache_lengths_tensor,
):
seqlen = Seqlen(
input_lengths=input_lengths,
input_lengths=input_lengths_tensor,
cache_lengths=cache_lengths_tensor,
cu_seqlen_q=cu_seqlen_prefill,
max_q=batch.max_input_length,
@ -1897,11 +2019,13 @@ class FlashCausalLM(Model):
cuda_graph["slots"].fill_(0)
cuda_graph["slots"][: slots.shape[0]] = slots
cuda_graph["input_lengths"].zero_()
cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
cuda_graph["input_lengths"][: input_lengths_tensor.shape[0]] = (
input_lengths_tensor
)
cuda_graph["cache_lengths"].zero_()
cuda_graph["cache_lengths"][
: cache_lengths_tensor.shape[0]
] = cache_lengths_tensor
cuda_graph["cache_lengths"][: cache_lengths_tensor.shape[0]] = (
cache_lengths_tensor
)
with self._forward_context(
block_tables=cuda_graph["block_tables"],

View File

@ -2,6 +2,7 @@ import torch
import triton
import triton.language as tl
from collections import Counter
from loguru import logger
from typing import List, Optional
@ -130,6 +131,7 @@ def prepare_position_slot_ids(
cu_slots: torch.Tensor,
position_ids: torch.Tensor,
slot_indices: torch.Tensor,
slots: torch.Tensor,
):
def grid(meta):
return (
@ -140,6 +142,12 @@ def prepare_position_slot_ids(
triton_prepare_position_slot_ids[grid](
cache_lengths, cu_seqlen, cu_slots, position_ids, slot_indices, BLOCK_SIZE=256
)
SLOTS = slots[slot_indices]
most_common = Counter(SLOTS.view(-1).tolist()).most_common(3)
if torch.unique(SLOTS.view(-1)).shape != SLOTS.view(-1).shape:
import ipdb
ipdb.set_trace()
def slots_filtering(
@ -158,6 +166,10 @@ def slots_filtering(
triton_slots_filtering[grid](
slots, filtered_slots, slots_start, cu_slots, BLOCK_SIZE=256
)
assert torch.all(slots[slots_start] == filtered_slots[cu_slots[:-1]])
# assert torch.unique(slots).shape == slots.shape, (
# f"Slots {slots} {Counter(slots.tolist()).most_common(3)}"
# )
@triton.jit

View File

@ -11,6 +11,10 @@ from grpc_reflection.v1alpha import reflection
from pathlib import Path
from typing import List, Optional
from text_generation_server.models.flash_causal_lm import (
ASSERT_BATCH_IS_CORRECT,
ASSERT_SIMPLE,
)
from text_generation_server.cache import Cache
from text_generation_server.interceptor import ExceptionInterceptor
from text_generation_server.models import Model, get_model_with_lora_adapters
@ -164,6 +168,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
self.model.device,
)
else:
ASSERT_SIMPLE(request.batch)
batch = self.model.batch_type.from_pb(
request.batch, self.model.tokenizer, self.model.dtype, self.model.device
)
@ -178,6 +183,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
)
start_concat = time.time_ns()
batch = self.model.batch_type.concatenate([cached_batch, batch])
# ASSERT_BATCH_IS_CORRECT(batch)
concat_ns = time.time_ns() - start_concat
generations, next_batch, timings = self.model.generate_token(batch)