diff --git a/Cargo.lock b/Cargo.lock
index 62f9036c..248ab89f 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1169,6 +1169,16 @@ dependencies = [
  "cfg-if",
 ]
 
+[[package]]
+name = "env_logger"
+version = "0.8.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a19187fea3ac7e84da7dacf48de0c45d63c6a76f9490dae389aead16c243fce3"
+dependencies = [
+ "log",
+ "regex",
+]
+
 [[package]]
 name = "equivalent"
 version = "1.0.1"
@@ -3403,6 +3413,17 @@ version = "2.0.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "a993555f31e5a609f617c12db6250dedcac1b0a85076912c436e6fc9b2c8e6a3"
 
+[[package]]
+name = "quickcheck"
+version = "1.0.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "588f6378e4dd99458b60ec275b4477add41ce4fa9f64dcba6f15adccb19b50d6"
+dependencies = [
+ "env_logger",
+ "log",
+ "rand",
+]
+
 [[package]]
 name = "quote"
 version = "1.0.37"
@@ -4630,6 +4651,7 @@ dependencies = [
  "opentelemetry-otlp",
  "prost 0.12.6",
  "prost-build",
+ "quickcheck",
  "rand",
  "regex",
  "reqwest 0.11.27",
diff --git a/backends/client/src/v3/client.rs b/backends/client/src/v3/client.rs
index 968c1f45..2edbd903 100644
--- a/backends/client/src/v3/client.rs
+++ b/backends/client/src/v3/client.rs
@@ -7,6 +7,7 @@ use grpc_metadata::InjectTelemetryContext;
 use pb::generate::v3::text_generation_service_client::TextGenerationServiceClient;
 use pb::generate::v3::*;
 use std::cmp::min;
+use std::io::Error;
 use std::time::Duration;
 use tonic::transport::{Channel, Uri};
 use tracing::instrument;
@@ -232,6 +233,20 @@ impl Client {
         batch: Batch,
         cached_batch: Option<CachedBatch>,
     ) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
+        let slots: Vec<_> = batch
+            .requests
+            .iter()
+            .map(|r| &r.slots[r.cache_len as usize..])
+            .flatten()
+            .collect();
+
+        assert_eq!(
+            slots.len(),
+            slots.iter().collect::<std::collections::HashSet<_>>().len()
+        );
+        if slots.len() != slots.iter().collect::<std::collections::HashSet<_>>().len() {
+            std::process::exit(1);
+        }
         let request = tonic::Request::new(PrefillRequest {
             batch: Some(batch),
             cached_batch,
         })
diff --git a/backends/v3/Cargo.toml b/backends/v3/Cargo.toml
index 996290ed..f56ce7d1 100644
--- a/backends/v3/Cargo.toml
+++ b/backends/v3/Cargo.toml
@@ -63,6 +63,7 @@ base64 = { workspace = true }
 prost = "^0.12"
 tonic = "^0.10"
 tower = "^0.4"
+quickcheck = "1.0.3"
 
 [build-dependencies]
 tonic-build = "0.10.1"
diff --git a/backends/v3/src/backend.rs b/backends/v3/src/backend.rs
index 98e8d76f..9ce038c6 100644
--- a/backends/v3/src/backend.rs
+++ b/backends/v3/src/backend.rs
@@ -5,6 +5,7 @@ use crate::client::{
 use crate::queue::{Entry, Queue};
 use async_trait::async_trait;
 use nohash_hasher::IntMap;
+use std::collections::HashMap;
 use std::sync::Arc;
 use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
 use text_generation_router::validation::ValidGenerateRequest;
@@ -13,6 +14,7 @@ use tokio::sync::mpsc::error::SendError;
 use tokio::sync::{mpsc, Notify};
 use tokio::time::Instant;
 use tokio_stream::wrappers::UnboundedReceiverStream;
+use tonic::IntoRequest;
 use tracing::{info_span, instrument, Instrument, Span};
 
 pub struct BackendV3 {
@@ -121,6 +123,35 @@ impl Backend for BackendV3 {
     }
 }
 
+impl Batch {
+    pub fn check(&self) -> Result<(), InferError> {
+        let slots: Vec<_> = self
+            .requests
+            .iter()
+            .map(|r| &r.slots[r.cache_len as usize..])
+            .flatten()
+            .collect();
+
+        // assert_eq!(
+        //     slots.len(),
+        //     slots.iter().collect::<std::collections::HashSet<_>>().len()
+        // );
+        if slots.len() != slots.iter().collect::<std::collections::HashSet<_>>().len() {
+            let mut map: HashMap<u32, usize> = HashMap::new();
+            for slot in slots {
+                *map.entry(*slot).or_default() += 1usize;
+            }
+            let duplicates: HashMap<_, _> = map.into_iter().filter(|(_slot, c)| *c > 1).collect();
+
+            Err(InferError::GenerationError(format!(
+                "Invalid batch: {duplicates:?}",
+            )))
+        } else {
+            Ok(())
+        }
+    }
+}
+
 /// Batching logic
 /// Will be launched in a background Tokio task
 ///
@@ -154,6 +185,7 @@ pub(crate) async fn batching_task(
     )
     .await
     {
+        batch.check().unwrap();
         let mut cached_batch = prefill(&mut client, batch, None, &mut entries)
             .instrument(span)
             .await;
@@ -205,6 +237,7 @@ pub(crate) async fn batching_task(
             .next_batch(min_size, max_size, prefill_token_budget, token_budget)
             .await
         {
+            new_batch.check().unwrap();
             // Tracking metrics
             if min_size.is_some() {
                 metrics::counter!("tgi_batch_concat", "reason" => "backpressure")
@@ -225,6 +258,7 @@ pub(crate) async fn batching_task(
                     // concatenated during the prefill op server side
                     entries.extend(new_entries);
                     // Generate one token for both the cached batch and the new batch
+                    new_batch.check().unwrap();
                     let new_cached_batch =
                         prefill(&mut client, new_batch, cached_batch, &mut entries)
                             .instrument(span)
                             .await;
@@ -249,6 +283,7 @@ pub(crate) async fn batching_task(
                     });
 
                     // Generate one token for this new batch to have the attention past in cache
+                    new_batch.check().unwrap();
                     let new_cached_batch =
                         prefill(&mut client, new_batch, None, &mut new_entries)
                             .instrument(span)
diff --git a/backends/v3/src/block_allocator.rs b/backends/v3/src/block_allocator.rs
index e7f3d85a..46776049 100644
--- a/backends/v3/src/block_allocator.rs
+++ b/backends/v3/src/block_allocator.rs
@@ -19,7 +19,13 @@ pub struct BlockAllocation {
 impl Drop for BlockAllocation {
     fn drop(&mut self) {
         if let Some(block_allocator) = self.block_allocator.as_mut() {
+            tracing::debug!("Freeing block {}", self.allocation_id);
             block_allocator.free(self.blocks.clone(), self.allocation_id)
+        } else {
+            #[cfg(not(test))]
+            {
+                panic!("We didn't have a block allocator");
+            }
         }
     }
 }
diff --git a/backends/v3/src/client/grpc_client.rs b/backends/v3/src/client/grpc_client.rs
index f4942f64..35751a1e 100644
--- a/backends/v3/src/client/grpc_client.rs
+++ b/backends/v3/src/client/grpc_client.rs
@@ -1,6 +1,7 @@
 /// Single shard Client
 use crate::client::{pb, Chunk};
 use crate::client::{ClientError, Result, WARMUP_IMAGE_BASE64};
+use axum::http::Error;
 use base64::engine::general_purpose::STANDARD;
 use base64::Engine;
 use grpc_metadata::InjectTelemetryContext;
@@ -232,11 +233,28 @@ impl Client {
         batch: Batch,
         cached_batch: Option<CachedBatch>,
     ) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
+        let slots: Vec<_> = batch
+            .requests
+            .iter()
+            .map(|r| &r.slots[r.cache_len as usize..])
+            .flatten()
+            .collect();
+
+        assert_eq!(
+            slots.len(),
+            slots.iter().collect::<std::collections::HashSet<_>>().len()
+        );
+        if slots.len() != slots.iter().collect::<std::collections::HashSet<_>>().len() {
+            std::process::exit(1);
+        }
         let request = tonic::Request::new(PrefillRequest {
             batch: Some(batch),
             cached_batch,
         })
         .inject_context();
+        // if slots.len() != slots.iter().collect::<std::collections::HashSet<_>>().len() {
+        //     return Err(Error::from("Test"));
+        // }
         let response = self.stub.prefill(request).await?.into_inner();
         Ok((
             response.generations,
diff --git a/backends/v3/src/queue.rs b/backends/v3/src/queue.rs
index 249eebf7..739006b0 100644
--- a/backends/v3/src/queue.rs
+++ b/backends/v3/src/queue.rs
@@ -5,7 +5,7 @@ use crate::client::{
 };
 use nohash_hasher::{BuildNoHashHasher, IntMap};
 use std::cmp::max;
-use std::collections::VecDeque;
+use std::collections::{HashMap, VecDeque};
 use text_generation_router::infer::InferError;
 use text_generation_router::infer::InferStreamResponse;
 use text_generation_router::validation::{
@@ -269,6 +269,8 @@ impl State {
         let mut decode_tokens: u32 = 0;
         let mut max_blocks = 0;
 
+        let mut viewed: HashMap<u32, usize> = HashMap::new();
+
         // Pop entries starting from the front of the queue
         'entry_loop: while let Some((id, entry)) = self.entries.pop_front() {
             // Filter entries where the response receiver was dropped (== entries where the request
@@ -311,7 +313,7 @@ impl State {
                         + entry.request.stopping_parameters.max_new_tokens
                         + self.speculate
                         - 1;
-                    tracing::debug!("Allocating {tokens} with {input_ids:?}");
+                    // tracing::debug!("Allocating {tokens} with {input_ids:?}");
                     let block_allocation = match block_allocator.allocate(tokens, input_ids).await
                     {
                         None => {
@@ -322,10 +324,11 @@ impl State {
                             break 'entry_loop;
                         }
                         Some(mut block_allocation) => {
-                            tracing::debug!("Allocation: {block_allocation:?}");
+                            // tracing::debug!("Allocation: {block_allocation:?}");
                             max_blocks = max(max_blocks, block_allocation.blocks.len() as u32);
 
-                            if block_allocation.prefix_len == entry.request.input_length {
+                            if block_allocation.prefix_len >= entry.request.input_length {
+                                // panic!("Something wrong happened we have overmatched the prefix {} >= {}", block_allocation.prefix_len, entry.request.input_length);
                                 // The whole request was found in the radix trie
                                 // However, for the transformer forward to work, we need to
                                 // have at least one token of postfix.
@@ -336,6 +339,13 @@ impl State {
                         }
                     };
 
+                    let new_slots = &block_allocation.slots[block_allocation.prefix_len as usize..];
+                    for s in new_slots {
+                        let entry = viewed.entry(*s).or_default();
+                        *entry += 1;
+                        assert!(*entry <= 1);
+                    }
+
                     let postfix_len = entry.request.input_length - block_allocation.prefix_len;
 
                     if prefill_tokens + postfix_len > prefill_token_budget {
@@ -349,6 +359,14 @@ impl State {
                        } else {
                            // We cannot prefill even one token for this entry
                            // Add it back to the queue
+
+                           // Removing the allocations.
+ tracing::debug!("Removing some allocations"); + for s in new_slots { + let entry = viewed.entry(*s).or_default(); + *entry -= 1; + assert!(*entry <= 1); + } self.entries.push_front((id, entry)); } tracing::debug!( @@ -363,6 +381,12 @@ impl State { "Over budget: prefill_tokens={} > {prefill_token_budget}", prefill_tokens + postfix_len ); + tracing::debug!("Removing some allocations"); + for s in new_slots { + let entry = viewed.entry(*s).or_default(); + *entry -= 1; + assert!(*entry <= 1); + } self.entries.push_front((id, entry)); break 'entry_loop; } diff --git a/backends/v3/src/radix.rs b/backends/v3/src/radix.rs index 532ec6dd..c9675c22 100644 --- a/backends/v3/src/radix.rs +++ b/backends/v3/src/radix.rs @@ -1,5 +1,6 @@ use crate::block_allocator::{Allocator, BlockAllocation}; use slotmap::{DefaultKey, SlotMap}; +use std::collections::HashSet; use std::hash::{Hash, Hasher}; use std::{ collections::{BTreeSet, HashMap}, @@ -86,9 +87,11 @@ impl Allocator for RadixAllocator { ) -> Option { let mut blocks = vec![]; let prefix_node = if let Some(prefill_tokens) = prefill_tokens.as_ref() { - let node_id = self - .cache_blocks - .find(prefill_tokens.as_slice(), &mut blocks); + let node_id = self.cache_blocks.find( + // &prefill_tokens.as_slice()[..prefill_tokens.len().saturating_sub(1)], + &prefill_tokens.as_slice(), + &mut blocks, + ); node_id } else { self.cache_blocks.root_id() @@ -136,6 +139,26 @@ impl Allocator for RadixAllocator { slots }; + tracing::debug!("Allocated {}", self.allocation_id); + let slot_set = slots.iter().collect::>(); + let mut slot_count: HashMap = HashMap::new(); + for slot in &slots { + let entry = slot_count.entry(*slot).or_default(); + *entry += 1; + } + let duplicates: HashMap = + slot_count.into_iter().filter(|(_k, v)| *v > 1).collect(); + // assert_eq!(slots.len(), slot_set.len(), "Duplicates {duplicates:?}"); + + let free_set = self.free_blocks.iter().collect::>(); + assert_eq!( + free_set + .intersection(&slot_set) + .collect::>() + .len(), + 0 + ); + let allocation = RadixAllocation { prefix_node, cached_prefix_len: prefix_len, @@ -144,6 +167,7 @@ impl Allocator for RadixAllocator { self.allocation_id += 1; self.allocations.insert(self.allocation_id, allocation); + tracing::debug!("Allocated {}", self.allocation_id); Some(BlockAllocation { allocation_id: self.allocation_id, @@ -155,7 +179,8 @@ impl Allocator for RadixAllocator { } fn free(&mut self, blocks: Vec, allocation_id: u64) { - let allocation = match self.allocations.remove(&allocation_id) { + tracing::debug!("Radix free {allocation_id}"); + let allocation: RadixAllocation = match self.allocations.remove(&allocation_id) { Some(allocation) => allocation, None => unreachable!("Tried to free an unknown allocation."), }; @@ -283,7 +308,7 @@ impl RadixTrie { } /// Find worker. - fn find_(&mut self, mut node_id: NodeId, key: &[u32], blocks: &mut Vec) -> NodeId { + fn find_(&mut self, node_id: NodeId, key: &[u32], blocks: &mut Vec) -> NodeId { let node = &self.nodes[node_id]; if key.len() >= self.block_size { @@ -295,9 +320,13 @@ impl RadixTrie { assert_eq!(shared_prefix_len % self.block_size, 0); blocks.extend(&child.blocks[..shared_prefix_len / self.block_size]); + // A node represents the prefix of its children. So, only + // recurse when there is a full prefix match. 
                 let key = &key[shared_prefix_len..];
-                if !key.is_empty() {
-                    node_id = self.find_(child_id, key, blocks);
+                if !key.is_empty() && shared_prefix_len == child.key.len() {
+                    return self.find_(child_id, key, blocks);
+                } else {
+                    return child_id;
+                }
             }
         }
@@ -369,7 +398,6 @@ impl RadixTrie {
         while let Some((last_access, node_id)) = self.leaves.pop_first() {
             let blocks_needed = n_blocks.saturating_sub(evicted.len());
-            tracing::debug!("Evicting node {node_id:?} ");
 
             let node = self.nodes.get(node_id).expect("Leave does not exist");
             assert_eq!(
@@ -381,11 +409,8 @@ impl RadixTrie {
             if blocks_needed >= node.blocks.len() {
                 // We need to evict the whole node if we need more blocks than it has.
                 let node = self.remove_node(node_id);
+                tracing::debug!("Evicted node {node_id:?} got back {}", node.blocks.len());
                 evicted.extend(node.blocks);
-
-                if evicted.len() >= n_blocks {
-                    break;
-                }
             } else {
                 // The node has more blocks than needed, so we'll just remove
                 // the required number of blocks and leave the remaining blocks
@@ -397,6 +422,10 @@ impl RadixTrie {
                 node.key.truncate(truncate_tokens);
                 evicted.extend(node.blocks.split_off(truncate_blocks));
                 self.leaves.insert((last_access, node_id));
+                tracing::debug!("Evicted partial node {node_id:?} got {blocks_needed} back",);
+            }
+            if evicted.len() >= n_blocks {
+                tracing::debug!("Got enough {}", evicted.len());
                 break;
             }
         }
@@ -873,4 +902,76 @@ mod tests {
         // Clear out the whole trie.
         assert_eq!(trie.evict(10), vec![1, 2, 3, 0, 1]);
     }
+
+    #[derive(Clone, Debug)]
+    enum Command {
+        Allocate {
+            tokens: u32,
+            prefill: Option<Arc<Vec<u32>>>,
+        },
+        Free {
+            blocks: Vec<u32>,
+            allocation_id: u64,
+        },
+    }
+
+    #[derive(Clone)]
+    struct Vocab(u32);
+
+    impl Arbitrary for Vocab {
+        fn arbitrary(gen: &mut Gen) -> Self {
+            let free = bool::arbitrary(gen);
+            if free {
+                Vocab(0)
+            } else {
+                Vocab(1)
+            }
+        }
+    }
+
+    use quickcheck::quickcheck;
+    use quickcheck::{Arbitrary, Gen};
+
+    impl Arbitrary for Command {
+        fn arbitrary(gen: &mut Gen) -> Self {
+            let free = bool::arbitrary(gen);
+            if free {
+                let blocks: Vec<u32> = Vec::arbitrary(gen);
+                let allocation_id = u64::arbitrary(gen);
+                Command::Free {
+                    blocks,
+                    allocation_id,
+                }
+            } else {
+                let tokens = u32::arbitrary(gen);
+                let prefill_tokens: Vec<Vocab> = Vec::arbitrary(gen);
+                let prefill_tokens = prefill_tokens.into_iter().map(|v| v.0).collect();
+                let prefill = Some(Arc::new(prefill_tokens));
+                Command::Allocate { tokens, prefill }
+            }
+        }
+    }
+
+    quickcheck! {
+        fn allocator_commands(commands: Vec<Command>) -> bool {
+            let mut cache = RadixAllocator::new(1, 20, None);
+            let mut allocations = vec![];
+            for command in commands{
+                match command{
+                    Command::Allocate{tokens, prefill} => {
+                        let allocation = cache.allocate(tokens, prefill);
+                        if let Some(allocation) = allocation{
+                            allocations.push(allocation.allocation_id);
+                        }
+                    }
+                    Command::Free{blocks, allocation_id} => {
+                        if allocations.contains(&allocation_id){
+                            cache.free(blocks, allocation_id);
+                        }
+                    }
+                }
+            }
+            true
+        }
+    }
 }
diff --git a/router/src/server.rs b/router/src/server.rs
index 9e57af27..cfc73aef 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -505,7 +505,7 @@ async fn generate_stream_internal(
     let start_time = Instant::now();
     metrics::counter!("tgi_request_count").increment(1);
 
-    tracing::debug!("Input: {}", req.inputs);
+    // tracing::debug!("Input: {}", req.inputs);
 
     let compute_characters = req.inputs.chars().count();
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 1073f4f9..c5140392 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -4,6 +4,7 @@ import os
 import time
 import torch
 import torch.distributed
+from collections import Counter
 
 import numpy as np
 
@@ -87,6 +88,71 @@ tracer = trace.get_tracer(__name__)
 
 SLIDING_WINDOW: Optional[int] = None
 
+WARMUP = True
+
+
+def ASSERT_SIMPLE(batch):
+    slots = []
+    for r in batch.requests:
+        slots.extend(r.slots[r.cache_len :])
+    assert len(set(slots)) == len(slots)
+
+
+def ASSERT_BATCH_IS_CORRECT(batch):
+    global WARMUP
+    input_ids = batch.input_ids
+    position_ids = batch.position_ids
+    cu_seqlen_prefill = batch.cu_seqlen_prefill
+    # kv_cache = self.kv_cache
+    block_tables = batch.block_tables_tensor
+    slots = batch.slots[batch.slot_indices]
+    input_lengths_tensor = batch.input_lengths_tensor
+    cache_lengths_tensor = batch.cache_lengths_tensor
+    max_s = batch.max_current_length
+    lm_head_indices = batch.prefill_head_indices
+    assert input_ids.shape == position_ids.shape
+    # print(input_lengths_tensor, cache_lengths_tensor, slots, block_tables)
+    assert input_lengths_tensor.shape == cache_lengths_tensor.shape
+    assert torch.all(cache_lengths_tensor >= 0)
+    assert torch.all(input_lengths_tensor > 0)
+
+    loffset = 0
+    coffset = 0
+    assert torch.unique(slots).shape == slots.shape, (
+        f"Slots {slots} - Cache {cache_lengths_tensor} Input {input_lengths_tensor} - Slto indices {batch.slot_indices} - Counter {Counter(slots.tolist()).most_common(3)} "
+    )
+
+    previous_slots = []
+    previous_blocks = []
+    for input_length, cache_length in zip(input_lengths_tensor, cache_lengths_tensor):
+        slot = slots[loffset : loffset + input_length]
+        blocks = block_tables[coffset][: input_length + cache_length]
+        assert len(slot.shape) == 1
+        # print(f"Blocks {blocks} - Slots {slots}")
+        assert torch.all(blocks[cache_length : cache_length + input_length] == slot)
+        if not WARMUP:
+            assert torch.all(blocks != 0)
+            assert torch.unique(blocks).shape == blocks.shape
+
+        for pblocks in previous_blocks:
+            m = min(pblocks.shape[0], blocks.shape[0])
+            diff = pblocks[:m] - blocks[:m]
+            NZ = diff.nonzero().view(-1)
+            if NZ.shape[0]:
+                # Remove the first offset
+                assert NZ[0] + NZ.shape[0] == m
+                NZ = NZ - NZ[0]
+                assert torch.all(NZ >= 0), f"{pblocks} - blocks {blocks} NZ {NZ}"
+                assert torch.all(NZ == torch.arange(NZ.shape[0], device=NZ.device))
+
+        loffset += input_length
+        coffset += 1
+        previous_slots.append(slot)
+        previous_blocks.append(blocks)
+    # assert cu_seqlen_prefill.shape == position_ids.shape
+    WARMUP = False
+
+
 def small_power_of_2(n: int):
     return 1 << ((n - 1).bit_length() - 1)
 
@@ -135,6 +201,11 @@ def init_cpu_threads_env(rank_id: int, world_size: int):
     )
 
 
+from collections import defaultdict
+
+HISTORY = defaultdict(list)
+
+
 @dataclass
 class FlashCausalLMBatch(Batch):
     batch_id: int
@@ -262,6 +333,7 @@ class FlashCausalLMBatch(Batch):
         dtype: torch.dtype,
         device: torch.device,
     ) -> "FlashCausalLMBatch":
+        HISTORY[pb.id].append(("TOKENIZED"))
         speculate = get_speculate()
 
         cache_lengths = []
@@ -290,6 +362,7 @@ class FlashCausalLMBatch(Batch):
         block_tables_ragged = []
 
         # Parse batch
+        viewed = set()
        for i, (r, tokenized_input) in enumerate(
            zip(pb.requests, batch_tokenized_inputs)
        ):
@@ -304,10 +377,16 @@ class FlashCausalLMBatch(Batch):
             prompt_lengths.append(prompt_length)
 
             cache_length = r.cache_len
+            new_slots = r.slots[cache_length:]
+            if set(new_slots).intersection(viewed):
+                import ipdb
-            assert (
-                cache_length <= prompt_length
-            ), f"Prefix {cache_length} vs input {prompt_length}"
+                ipdb.set_trace()
+            viewed.update(set(new_slots))
+
+            assert cache_length <= prompt_length, (
+                f"Prefix {cache_length} vs input {prompt_length}"
+            )
             if cache_length == prompt_length:
                 assert False, "unreachable"
@@ -325,9 +404,9 @@ class FlashCausalLMBatch(Batch):
                     postfix_ids = tokenized_input[
                         cache_length : cache_length + input_length
                     ]
-                    assert (
-                        len(postfix_ids) == input_length
-                    ), "Rust and Python tokenizers are not aligned"
+                    assert len(postfix_ids) == input_length, (
+                        "Rust and Python tokenizers are not aligned"
+                    )
                 else:
                     # Use all the remaining ids
                     postfix_ids = tokenized_input[cache_length:]
@@ -378,6 +457,7 @@ class FlashCausalLMBatch(Batch):
             cu_blocks.append(len(block_tables_ragged))
 
             slots.extend(request_slots)
+
             cu_slots.append(len(slots))
 
             cache_lengths.append(cache_length)
@@ -392,6 +472,25 @@ class FlashCausalLMBatch(Batch):
                 prompt_length + max_new_tokens + speculative_length,
             )
 
+        # offset = 0
+        # new_slots = []
+        # total_slots = []
+        # for cache_length, input_length in zip(cache_lengths, input_lengths):
+        #     new_slots_ = slots[
+        #         offset + cache_length : offset + cache_length + input_length
+        #     ]
+        #     offset += cache_length + input_length
+        #     new_slots.extend(new_slots_)
+        #     total_slots.append(new_slots_)
+        # if new_slots:
+        #     if Counter(new_slots).most_common(1)[0][1] != 1:
+        #         import ipdb
+
+        #         ipdb.set_trace()
+        #     assert Counter(new_slots).most_common(1)[0][1] == 1, (
+        #         f"New slots {new_slots}"
+        #     )
+
         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
             next_token_chooser_parameters, dtype, device, tokenizer
         )
@@ -496,6 +595,7 @@ class FlashCausalLMBatch(Batch):
 
     @tracer.start_as_current_span("filter")
     def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
+        HISTORY[self.batch_id].append(("FILTER", request_ids))
         if len(request_ids) == 0:
             raise ValueError("Batch must have at least one request")
         # We assume that if len(requests) == len(self) then the requests are the same
@@ -702,6 +802,7 @@ class FlashCausalLMBatch(Batch):
     @classmethod
     @tracer.start_as_current_span("concatenate")
     def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
+        HISTORY[batches[0].batch_id].append(("CONCATENATE", batches))
         # Batch attributes
         requests = []
         requests_idx_mapping = {}
@@ -884,6 +985,15 @@ class FlashCausalLMBatch(Batch):
             cumulative_slots += len(batch.slots)
             cumulative_batch_size += len(batch)
 
+        if slot_indices:
+            new_slots = slots[slot_indices]
+            import ipdb
+
+            ipdb.set_trace()
+            assert torch.unique(new_slots).shape == new_slots.shape, (
+                f"Slots {new_slots} - Cache {cache_lengths_tensor} Input {input_lengths_tensor} - Slto indices {slot_indices} - Counter {Counter(new_slots.tolist()).most_common(3)} "
+            )
+
         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
             next_token_chooser_parameters,
             dtype=batches[0].next_token_chooser.dtype,
@@ -991,7 +1101,17 @@ class FlashCausalLMBatch(Batch):
                 cu_slots_gpu,
                 self.position_ids,
                 self.slot_indices,
+                self.slots,
             )
+            SLOTS = self.slots[self.slot_indices]
+            most_common = Counter(SLOTS.view(-1).tolist()).most_common(3)
+            if torch.unique(SLOTS.view(-1)).shape != SLOTS.view(-1).shape:
+                import ipdb
+
+                ipdb.set_trace()
+            assert torch.unique(SLOTS.view(-1)).shape == SLOTS.view(-1).shape, (
+                f"Slots {self.slots.view(-1)} Indices {self.slot_indices} - COUNTER {most_common} - Diff {self.slots == most_common[0][0]}"
+            )
 
         sliding_window = get_sliding_windows()
         position_ids = []
@@ -1813,7 +1933,7 @@ class FlashCausalLM(Model):
         kv_cache = self.kv_cache
         block_tables = batch.block_tables_tensor
         slots = batch.slots[batch.slot_indices]
-        input_lengths = batch.input_lengths_tensor
+        input_lengths_tensor = batch.input_lengths_tensor
         cache_lengths_tensor = batch.cache_lengths_tensor
         max_s = batch.max_current_length
         lm_head_indices = batch.prefill_head_indices
@@ -1832,6 +1952,8 @@ class FlashCausalLM(Model):
         else:
             cuda_graph = None
 
+        ASSERT_BATCH_IS_CORRECT(batch)
+
         if cu_seqlen_prefill is not None or cuda_graph is None:
             if ATTENTION == "flashinfer":
                 block_tables = block_tables_to_ragged(
@@ -1845,11 +1967,11 @@ class FlashCausalLM(Model):
             with self._forward_context(
                 block_tables=block_tables,
                 cu_seqlen_prefill=cu_seqlen_prefill,
-                input_lengths_tensor=input_lengths,
+                input_lengths_tensor=input_lengths_tensor,
                 cache_lengths_tensor=cache_lengths_tensor,
             ):
                 seqlen = Seqlen(
-                    input_lengths=input_lengths,
+                    input_lengths=input_lengths_tensor,
                     cache_lengths=cache_lengths_tensor,
                     cu_seqlen_q=cu_seqlen_prefill,
                     max_q=batch.max_input_length,
@@ -1897,11 +2019,13 @@ class FlashCausalLM(Model):
             cuda_graph["slots"].fill_(0)
             cuda_graph["slots"][: slots.shape[0]] = slots
             cuda_graph["input_lengths"].zero_()
-            cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
+            cuda_graph["input_lengths"][: input_lengths_tensor.shape[0]] = (
+                input_lengths_tensor
+            )
             cuda_graph["cache_lengths"].zero_()
-            cuda_graph["cache_lengths"][
-                : cache_lengths_tensor.shape[0]
-            ] = cache_lengths_tensor
+            cuda_graph["cache_lengths"][: cache_lengths_tensor.shape[0]] = (
+                cache_lengths_tensor
+            )
 
             with self._forward_context(
                 block_tables=cuda_graph["block_tables"],
diff --git a/server/text_generation_server/models/metadata_kernels.py b/server/text_generation_server/models/metadata_kernels.py
index 42b77121..b9d89b1f 100644
--- a/server/text_generation_server/models/metadata_kernels.py
+++ b/server/text_generation_server/models/metadata_kernels.py
@@ -2,6 +2,7 @@ import torch
 
 import triton
 import triton.language as tl
+from collections import Counter
 
 from loguru import logger
 from typing import List, Optional
@@ -130,6 +131,7 @@ def prepare_position_slot_ids(
     cu_slots: torch.Tensor,
     position_ids: torch.Tensor,
     slot_indices: torch.Tensor,
+    slots: torch.Tensor,
 ):
     def grid(meta):
         return (
@@ -140,6 +142,12 @@ def prepare_position_slot_ids(
     triton_prepare_position_slot_ids[grid](
         cache_lengths, cu_seqlen, cu_slots, position_ids, slot_indices, BLOCK_SIZE=256
     )
+    SLOTS = slots[slot_indices]
+    most_common = Counter(SLOTS.view(-1).tolist()).most_common(3)
+    if torch.unique(SLOTS.view(-1)).shape != SLOTS.view(-1).shape:
+        import ipdb
+
+        ipdb.set_trace()
 
 
 def slots_filtering(
@@ -158,6 +166,10 @@ def slots_filtering(
     triton_slots_filtering[grid](
         slots, filtered_slots, slots_start, cu_slots, BLOCK_SIZE=256
     )
+    assert torch.all(slots[slots_start] == filtered_slots[cu_slots[:-1]])
+    # assert torch.unique(slots).shape == slots.shape, (
+    #     f"Slots {slots} {Counter(slots.tolist()).most_common(3)}"
+    # )
 
 
 @triton.jit
diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py
index 935e0985..5269f522 100644
--- a/server/text_generation_server/server.py
+++ b/server/text_generation_server/server.py
@@ -11,6 +11,10 @@ from grpc_reflection.v1alpha import reflection
 from pathlib import Path
 from typing import List, Optional
 
+from text_generation_server.models.flash_causal_lm import (
+    ASSERT_BATCH_IS_CORRECT,
+    ASSERT_SIMPLE,
+)
 from text_generation_server.cache import Cache
 from text_generation_server.interceptor import ExceptionInterceptor
 from text_generation_server.models import Model, get_model_with_lora_adapters
@@ -164,6 +168,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
                 self.model.device,
             )
         else:
+            ASSERT_SIMPLE(request.batch)
             batch = self.model.batch_type.from_pb(
                 request.batch, self.model.tokenizer, self.model.dtype, self.model.device
             )
@@ -178,6 +183,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
             )
         start_concat = time.time_ns()
         batch = self.model.batch_type.concatenate([cached_batch, batch])
+        # ASSERT_BATCH_IS_CORRECT(batch)
         concat_ns = time.time_ns() - start_concat
 
         generations, next_batch, timings = self.model.generate_token(batch)