mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-22 15:32:08 +00:00
All the assertions.
Invariants added Remove the logs.
This commit is contained in:
parent
463228ebfc
commit
ddf0b02240
22
Cargo.lock
generated
22
Cargo.lock
generated
@ -1169,6 +1169,16 @@ dependencies = [
|
||||
"cfg-if",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "env_logger"
|
||||
version = "0.8.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a19187fea3ac7e84da7dacf48de0c45d63c6a76f9490dae389aead16c243fce3"
|
||||
dependencies = [
|
||||
"log",
|
||||
"regex",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "equivalent"
|
||||
version = "1.0.1"
|
||||
@ -3403,6 +3413,17 @@ version = "2.0.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a993555f31e5a609f617c12db6250dedcac1b0a85076912c436e6fc9b2c8e6a3"
|
||||
|
||||
[[package]]
|
||||
name = "quickcheck"
|
||||
version = "1.0.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "588f6378e4dd99458b60ec275b4477add41ce4fa9f64dcba6f15adccb19b50d6"
|
||||
dependencies = [
|
||||
"env_logger",
|
||||
"log",
|
||||
"rand",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "quote"
|
||||
version = "1.0.37"
|
||||
@ -4630,6 +4651,7 @@ dependencies = [
|
||||
"opentelemetry-otlp",
|
||||
"prost 0.12.6",
|
||||
"prost-build",
|
||||
"quickcheck",
|
||||
"rand",
|
||||
"regex",
|
||||
"reqwest 0.11.27",
|
||||
|
@ -7,6 +7,7 @@ use grpc_metadata::InjectTelemetryContext;
|
||||
use pb::generate::v3::text_generation_service_client::TextGenerationServiceClient;
|
||||
use pb::generate::v3::*;
|
||||
use std::cmp::min;
|
||||
use std::io::Error;
|
||||
use std::time::Duration;
|
||||
use tonic::transport::{Channel, Uri};
|
||||
use tracing::instrument;
|
||||
@ -232,6 +233,20 @@ impl Client {
|
||||
batch: Batch,
|
||||
cached_batch: Option<CachedBatch>,
|
||||
) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
|
||||
let slots: Vec<_> = batch
|
||||
.requests
|
||||
.iter()
|
||||
.map(|r| &r.slots[r.cache_len as usize..])
|
||||
.flatten()
|
||||
.collect();
|
||||
|
||||
assert_eq!(
|
||||
slots.len(),
|
||||
slots.iter().collect::<std::collections::HashSet<_>>().len()
|
||||
);
|
||||
if slots.len() != slots.iter().collect::<std::collections::HashSet<_>>().len() {
|
||||
std::process::exit(1);
|
||||
}
|
||||
let request = tonic::Request::new(PrefillRequest {
|
||||
batch: Some(batch),
|
||||
cached_batch,
|
||||
|
@ -63,6 +63,7 @@ base64 = { workspace = true }
|
||||
prost = "^0.12"
|
||||
tonic = "^0.10"
|
||||
tower = "^0.4"
|
||||
quickcheck = "1.0.3"
|
||||
|
||||
[build-dependencies]
|
||||
tonic-build = "0.10.1"
|
||||
|
@ -5,6 +5,7 @@ use crate::client::{
|
||||
use crate::queue::{Entry, Queue};
|
||||
use async_trait::async_trait;
|
||||
use nohash_hasher::IntMap;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
|
||||
use text_generation_router::validation::ValidGenerateRequest;
|
||||
@ -13,6 +14,7 @@ use tokio::sync::mpsc::error::SendError;
|
||||
use tokio::sync::{mpsc, Notify};
|
||||
use tokio::time::Instant;
|
||||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||
use tonic::IntoRequest;
|
||||
use tracing::{info_span, instrument, Instrument, Span};
|
||||
|
||||
pub struct BackendV3 {
|
||||
@ -121,6 +123,35 @@ impl Backend for BackendV3 {
|
||||
}
|
||||
}
|
||||
|
||||
impl Batch {
|
||||
pub fn check(&self) -> Result<(), InferError> {
|
||||
let slots: Vec<_> = self
|
||||
.requests
|
||||
.iter()
|
||||
.map(|r| &r.slots[r.cache_len as usize..])
|
||||
.flatten()
|
||||
.collect();
|
||||
|
||||
// assert_eq!(
|
||||
// slots.len(),
|
||||
// slots.iter().collect::<std::collections::HashSet<_>>().len()
|
||||
// );
|
||||
if slots.len() != slots.iter().collect::<std::collections::HashSet<_>>().len() {
|
||||
let mut map: HashMap<u32, usize> = HashMap::new();
|
||||
for slot in slots {
|
||||
*map.entry(*slot).or_default() += 1usize;
|
||||
}
|
||||
let duplicates: HashMap<_, _> = map.into_iter().filter(|(_slot, c)| *c > 1).collect();
|
||||
|
||||
Err(InferError::GenerationError(format!(
|
||||
"Invalid batch: {duplicates:?}",
|
||||
)))
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Batching logic
|
||||
/// Will be launched in a background Tokio task
|
||||
///
|
||||
@ -154,6 +185,7 @@ pub(crate) async fn batching_task(
|
||||
)
|
||||
.await
|
||||
{
|
||||
batch.check().unwrap();
|
||||
let mut cached_batch = prefill(&mut client, batch, None, &mut entries)
|
||||
.instrument(span)
|
||||
.await;
|
||||
@ -205,6 +237,7 @@ pub(crate) async fn batching_task(
|
||||
.next_batch(min_size, max_size, prefill_token_budget, token_budget)
|
||||
.await
|
||||
{
|
||||
new_batch.check().unwrap();
|
||||
// Tracking metrics
|
||||
if min_size.is_some() {
|
||||
metrics::counter!("tgi_batch_concat", "reason" => "backpressure")
|
||||
@ -225,6 +258,7 @@ pub(crate) async fn batching_task(
|
||||
// concatenated during the prefill op server side
|
||||
entries.extend(new_entries);
|
||||
// Generate one token for both the cached batch and the new batch
|
||||
new_batch.check().unwrap();
|
||||
let new_cached_batch =
|
||||
prefill(&mut client, new_batch, cached_batch, &mut entries)
|
||||
.instrument(span)
|
||||
@ -249,6 +283,7 @@ pub(crate) async fn batching_task(
|
||||
});
|
||||
|
||||
// Generate one token for this new batch to have the attention past in cache
|
||||
new_batch.check().unwrap();
|
||||
let new_cached_batch =
|
||||
prefill(&mut client, new_batch, None, &mut new_entries)
|
||||
.instrument(span)
|
||||
|
@ -19,7 +19,13 @@ pub struct BlockAllocation {
|
||||
impl Drop for BlockAllocation {
|
||||
fn drop(&mut self) {
|
||||
if let Some(block_allocator) = self.block_allocator.as_mut() {
|
||||
tracing::debug!("Freeing block {}", self.allocation_id);
|
||||
block_allocator.free(self.blocks.clone(), self.allocation_id)
|
||||
} else {
|
||||
#[cfg(not(test))]
|
||||
{
|
||||
panic!("We didn't have a block allocator");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
/// Single shard Client
|
||||
use crate::client::{pb, Chunk};
|
||||
use crate::client::{ClientError, Result, WARMUP_IMAGE_BASE64};
|
||||
use axum::http::Error;
|
||||
use base64::engine::general_purpose::STANDARD;
|
||||
use base64::Engine;
|
||||
use grpc_metadata::InjectTelemetryContext;
|
||||
@ -232,11 +233,28 @@ impl Client {
|
||||
batch: Batch,
|
||||
cached_batch: Option<CachedBatch>,
|
||||
) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
|
||||
let slots: Vec<_> = batch
|
||||
.requests
|
||||
.iter()
|
||||
.map(|r| &r.slots[r.cache_len as usize..])
|
||||
.flatten()
|
||||
.collect();
|
||||
|
||||
assert_eq!(
|
||||
slots.len(),
|
||||
slots.iter().collect::<std::collections::HashSet<_>>().len()
|
||||
);
|
||||
if slots.len() != slots.iter().collect::<std::collections::HashSet<_>>().len() {
|
||||
std::process::exit(1);
|
||||
}
|
||||
let request = tonic::Request::new(PrefillRequest {
|
||||
batch: Some(batch),
|
||||
cached_batch,
|
||||
})
|
||||
.inject_context();
|
||||
// if slots.len() != slots.iter().collect::<std::collections::HashSet<_>>().len() {
|
||||
// return Err(Error::from("Test"));
|
||||
// }
|
||||
let response = self.stub.prefill(request).await?.into_inner();
|
||||
Ok((
|
||||
response.generations,
|
||||
|
@ -5,7 +5,7 @@ use crate::client::{
|
||||
};
|
||||
use nohash_hasher::{BuildNoHashHasher, IntMap};
|
||||
use std::cmp::max;
|
||||
use std::collections::VecDeque;
|
||||
use std::collections::{HashMap, VecDeque};
|
||||
use text_generation_router::infer::InferError;
|
||||
use text_generation_router::infer::InferStreamResponse;
|
||||
use text_generation_router::validation::{
|
||||
@ -269,6 +269,8 @@ impl State {
|
||||
let mut decode_tokens: u32 = 0;
|
||||
let mut max_blocks = 0;
|
||||
|
||||
let mut viewed: HashMap<u32, usize> = HashMap::new();
|
||||
|
||||
// Pop entries starting from the front of the queue
|
||||
'entry_loop: while let Some((id, entry)) = self.entries.pop_front() {
|
||||
// Filter entries where the response receiver was dropped (== entries where the request
|
||||
@ -311,7 +313,7 @@ impl State {
|
||||
+ entry.request.stopping_parameters.max_new_tokens
|
||||
+ self.speculate
|
||||
- 1;
|
||||
tracing::debug!("Allocating {tokens} with {input_ids:?}");
|
||||
// tracing::debug!("Allocating {tokens} with {input_ids:?}");
|
||||
|
||||
let block_allocation = match block_allocator.allocate(tokens, input_ids).await {
|
||||
None => {
|
||||
@ -322,10 +324,11 @@ impl State {
|
||||
break 'entry_loop;
|
||||
}
|
||||
Some(mut block_allocation) => {
|
||||
tracing::debug!("Allocation: {block_allocation:?}");
|
||||
// tracing::debug!("Allocation: {block_allocation:?}");
|
||||
max_blocks = max(max_blocks, block_allocation.blocks.len() as u32);
|
||||
|
||||
if block_allocation.prefix_len == entry.request.input_length {
|
||||
if block_allocation.prefix_len >= entry.request.input_length {
|
||||
// panic!("Something wrong happened we have overmatched the prefix {} >= {}", block_allocation.prefix_len, entry.request.input_length);
|
||||
// The whole request was found in the radix trie
|
||||
// However, for the transformer forward to work, we need to
|
||||
// have at least one token of postfix.
|
||||
@ -336,6 +339,13 @@ impl State {
|
||||
}
|
||||
};
|
||||
|
||||
let new_slots = &block_allocation.slots[block_allocation.prefix_len as usize..];
|
||||
for s in new_slots {
|
||||
let entry = viewed.entry(*s).or_default();
|
||||
*entry += 1;
|
||||
assert!(*entry <= 1);
|
||||
}
|
||||
|
||||
let postfix_len = entry.request.input_length - block_allocation.prefix_len;
|
||||
|
||||
if prefill_tokens + postfix_len > prefill_token_budget {
|
||||
@ -349,6 +359,14 @@ impl State {
|
||||
} else {
|
||||
// We cannot prefill even one token for this entry
|
||||
// Add it back to the queue
|
||||
|
||||
// Removing the allocations.
|
||||
tracing::debug!("Removing some allocations");
|
||||
for s in new_slots {
|
||||
let entry = viewed.entry(*s).or_default();
|
||||
*entry -= 1;
|
||||
assert!(*entry <= 1);
|
||||
}
|
||||
self.entries.push_front((id, entry));
|
||||
}
|
||||
tracing::debug!(
|
||||
@ -363,6 +381,12 @@ impl State {
|
||||
"Over budget: prefill_tokens={} > {prefill_token_budget}",
|
||||
prefill_tokens + postfix_len
|
||||
);
|
||||
tracing::debug!("Removing some allocations");
|
||||
for s in new_slots {
|
||||
let entry = viewed.entry(*s).or_default();
|
||||
*entry -= 1;
|
||||
assert!(*entry <= 1);
|
||||
}
|
||||
self.entries.push_front((id, entry));
|
||||
break 'entry_loop;
|
||||
}
|
||||
|
@ -1,5 +1,6 @@
|
||||
use crate::block_allocator::{Allocator, BlockAllocation};
|
||||
use slotmap::{DefaultKey, SlotMap};
|
||||
use std::collections::HashSet;
|
||||
use std::hash::{Hash, Hasher};
|
||||
use std::{
|
||||
collections::{BTreeSet, HashMap},
|
||||
@ -86,9 +87,11 @@ impl Allocator for RadixAllocator {
|
||||
) -> Option<BlockAllocation> {
|
||||
let mut blocks = vec![];
|
||||
let prefix_node = if let Some(prefill_tokens) = prefill_tokens.as_ref() {
|
||||
let node_id = self
|
||||
.cache_blocks
|
||||
.find(prefill_tokens.as_slice(), &mut blocks);
|
||||
let node_id = self.cache_blocks.find(
|
||||
// &prefill_tokens.as_slice()[..prefill_tokens.len().saturating_sub(1)],
|
||||
&prefill_tokens.as_slice(),
|
||||
&mut blocks,
|
||||
);
|
||||
node_id
|
||||
} else {
|
||||
self.cache_blocks.root_id()
|
||||
@ -136,6 +139,26 @@ impl Allocator for RadixAllocator {
|
||||
slots
|
||||
};
|
||||
|
||||
tracing::debug!("Allocated {}", self.allocation_id);
|
||||
let slot_set = slots.iter().collect::<HashSet<_>>();
|
||||
let mut slot_count: HashMap<u32, usize> = HashMap::new();
|
||||
for slot in &slots {
|
||||
let entry = slot_count.entry(*slot).or_default();
|
||||
*entry += 1;
|
||||
}
|
||||
let duplicates: HashMap<u32, usize> =
|
||||
slot_count.into_iter().filter(|(_k, v)| *v > 1).collect();
|
||||
// assert_eq!(slots.len(), slot_set.len(), "Duplicates {duplicates:?}");
|
||||
|
||||
let free_set = self.free_blocks.iter().collect::<HashSet<_>>();
|
||||
assert_eq!(
|
||||
free_set
|
||||
.intersection(&slot_set)
|
||||
.collect::<HashSet<_>>()
|
||||
.len(),
|
||||
0
|
||||
);
|
||||
|
||||
let allocation = RadixAllocation {
|
||||
prefix_node,
|
||||
cached_prefix_len: prefix_len,
|
||||
@ -144,6 +167,7 @@ impl Allocator for RadixAllocator {
|
||||
|
||||
self.allocation_id += 1;
|
||||
self.allocations.insert(self.allocation_id, allocation);
|
||||
tracing::debug!("Allocated {}", self.allocation_id);
|
||||
|
||||
Some(BlockAllocation {
|
||||
allocation_id: self.allocation_id,
|
||||
@ -155,7 +179,8 @@ impl Allocator for RadixAllocator {
|
||||
}
|
||||
|
||||
fn free(&mut self, blocks: Vec<u32>, allocation_id: u64) {
|
||||
let allocation = match self.allocations.remove(&allocation_id) {
|
||||
tracing::debug!("Radix free {allocation_id}");
|
||||
let allocation: RadixAllocation = match self.allocations.remove(&allocation_id) {
|
||||
Some(allocation) => allocation,
|
||||
None => unreachable!("Tried to free an unknown allocation."),
|
||||
};
|
||||
@ -283,7 +308,7 @@ impl RadixTrie {
|
||||
}
|
||||
|
||||
/// Find worker.
|
||||
fn find_(&mut self, mut node_id: NodeId, key: &[u32], blocks: &mut Vec<u32>) -> NodeId {
|
||||
fn find_(&mut self, node_id: NodeId, key: &[u32], blocks: &mut Vec<u32>) -> NodeId {
|
||||
let node = &self.nodes[node_id];
|
||||
|
||||
if key.len() >= self.block_size {
|
||||
@ -295,9 +320,13 @@ impl RadixTrie {
|
||||
assert_eq!(shared_prefix_len % self.block_size, 0);
|
||||
blocks.extend(&child.blocks[..shared_prefix_len / self.block_size]);
|
||||
|
||||
// A node represents the prefix of its children. So, only
|
||||
// recurse when there is a full prefix match.
|
||||
let key = &key[shared_prefix_len..];
|
||||
if !key.is_empty() {
|
||||
node_id = self.find_(child_id, key, blocks);
|
||||
if !key.is_empty() && shared_prefix_len == child.key.len() {
|
||||
return self.find_(child_id, key, blocks);
|
||||
} else {
|
||||
return child_id;
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -369,7 +398,6 @@ impl RadixTrie {
|
||||
|
||||
while let Some((last_access, node_id)) = self.leaves.pop_first() {
|
||||
let blocks_needed = n_blocks.saturating_sub(evicted.len());
|
||||
tracing::debug!("Evicting node {node_id:?} ");
|
||||
|
||||
let node = self.nodes.get(node_id).expect("Leave does not exist");
|
||||
assert_eq!(
|
||||
@ -381,11 +409,8 @@ impl RadixTrie {
|
||||
if blocks_needed >= node.blocks.len() {
|
||||
// We need to evict the whole node if we need more blocks than it has.
|
||||
let node = self.remove_node(node_id);
|
||||
tracing::debug!("Evicted node {node_id:?} got back {}", node.blocks.len());
|
||||
evicted.extend(node.blocks);
|
||||
|
||||
if evicted.len() >= n_blocks {
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
// The node has more blocks than needed, so we'll just remove
|
||||
// the required number of blocks and leave the remaining blocks
|
||||
@ -397,6 +422,10 @@ impl RadixTrie {
|
||||
node.key.truncate(truncate_tokens);
|
||||
evicted.extend(node.blocks.split_off(truncate_blocks));
|
||||
self.leaves.insert((last_access, node_id));
|
||||
tracing::debug!("Evicted partial node {node_id:?} got {blocks_needed} back",);
|
||||
}
|
||||
if evicted.len() >= n_blocks {
|
||||
tracing::debug!("Got enough {}", evicted.len());
|
||||
break;
|
||||
}
|
||||
}
|
||||
@ -873,4 +902,76 @@ mod tests {
|
||||
// Clear out the whole trie.
|
||||
assert_eq!(trie.evict(10), vec![1, 2, 3, 0, 1]);
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
enum Command {
|
||||
Allocate {
|
||||
tokens: u32,
|
||||
prefill: Option<Arc<Vec<u32>>>,
|
||||
},
|
||||
Free {
|
||||
blocks: Vec<u32>,
|
||||
allocation_id: u64,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct Vocab(u32);
|
||||
|
||||
impl Arbitrary for Vocab {
|
||||
fn arbitrary(gen: &mut Gen) -> Self {
|
||||
let free = bool::arbitrary(gen);
|
||||
if free {
|
||||
Vocab(0)
|
||||
} else {
|
||||
Vocab(1)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
use quickcheck::quickcheck;
|
||||
use quickcheck::{Arbitrary, Gen};
|
||||
|
||||
impl Arbitrary for Command {
|
||||
fn arbitrary(gen: &mut Gen) -> Self {
|
||||
let free = bool::arbitrary(gen);
|
||||
if free {
|
||||
let blocks: Vec<u32> = Vec::arbitrary(gen);
|
||||
let allocation_id = u64::arbitrary(gen);
|
||||
Command::Free {
|
||||
blocks,
|
||||
allocation_id,
|
||||
}
|
||||
} else {
|
||||
let tokens = u32::arbitrary(gen);
|
||||
let prefill_tokens: Vec<Vocab> = Vec::arbitrary(gen);
|
||||
let prefill_tokens = prefill_tokens.into_iter().map(|v| v.0).collect();
|
||||
let prefill = Some(Arc::new(prefill_tokens));
|
||||
Command::Allocate { tokens, prefill }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
quickcheck! {
|
||||
fn allocator_commands(commands: Vec<Command>) -> bool {
|
||||
let mut cache = RadixAllocator::new(1, 20, None);
|
||||
let mut allocations = vec![];
|
||||
for command in commands{
|
||||
match command{
|
||||
Command::Allocate{tokens, prefill} => {
|
||||
let allocation = cache.allocate(tokens, prefill);
|
||||
if let Some(allocation) = allocation{
|
||||
allocations.push(allocation.allocation_id);
|
||||
}
|
||||
}
|
||||
Command::Free{blocks, allocation_id} => {
|
||||
if allocations.contains(&allocation_id){
|
||||
cache.free(blocks, allocation_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -505,7 +505,7 @@ async fn generate_stream_internal(
|
||||
let start_time = Instant::now();
|
||||
metrics::counter!("tgi_request_count").increment(1);
|
||||
|
||||
tracing::debug!("Input: {}", req.inputs);
|
||||
// tracing::debug!("Input: {}", req.inputs);
|
||||
|
||||
let compute_characters = req.inputs.chars().count();
|
||||
|
||||
|
@ -4,6 +4,7 @@ import os
|
||||
import time
|
||||
import torch
|
||||
import torch.distributed
|
||||
from collections import Counter
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -87,6 +88,71 @@ tracer = trace.get_tracer(__name__)
|
||||
SLIDING_WINDOW: Optional[int] = None
|
||||
|
||||
|
||||
WARMUP = True
|
||||
|
||||
|
||||
def ASSERT_SIMPLE(batch):
|
||||
slots = []
|
||||
for r in batch.requests:
|
||||
slots.extend(r.slots[r.cache_len :])
|
||||
assert len(set(slots)) == len(slots)
|
||||
|
||||
|
||||
def ASSERT_BATCH_IS_CORRECT(batch):
|
||||
global WARMUP
|
||||
input_ids = batch.input_ids
|
||||
position_ids = batch.position_ids
|
||||
cu_seqlen_prefill = batch.cu_seqlen_prefill
|
||||
# kv_cache = self.kv_cache
|
||||
block_tables = batch.block_tables_tensor
|
||||
slots = batch.slots[batch.slot_indices]
|
||||
input_lengths_tensor = batch.input_lengths_tensor
|
||||
cache_lengths_tensor = batch.cache_lengths_tensor
|
||||
max_s = batch.max_current_length
|
||||
lm_head_indices = batch.prefill_head_indices
|
||||
assert input_ids.shape == position_ids.shape
|
||||
# print(input_lengths_tensor, cache_lengths_tensor, slots, block_tables)
|
||||
assert input_lengths_tensor.shape == cache_lengths_tensor.shape
|
||||
assert torch.all(cache_lengths_tensor >= 0)
|
||||
assert torch.all(input_lengths_tensor > 0)
|
||||
|
||||
loffset = 0
|
||||
coffset = 0
|
||||
assert torch.unique(slots).shape == slots.shape, (
|
||||
f"Slots {slots} - Cache {cache_lengths_tensor} Input {input_lengths_tensor} - Slto indices {batch.slot_indices} - Counter {Counter(slots.tolist()).most_common(3)} "
|
||||
)
|
||||
|
||||
previous_slots = []
|
||||
previous_blocks = []
|
||||
for input_length, cache_length in zip(input_lengths_tensor, cache_lengths_tensor):
|
||||
slot = slots[loffset : loffset + input_length]
|
||||
blocks = block_tables[coffset][: input_length + cache_length]
|
||||
assert len(slot.shape) == 1
|
||||
# print(f"Blocks {blocks} - Slots {slots}")
|
||||
assert torch.all(blocks[cache_length : cache_length + input_length] == slot)
|
||||
if not WARMUP:
|
||||
assert torch.all(blocks != 0)
|
||||
assert torch.unique(blocks).shape == blocks.shape
|
||||
|
||||
for pblocks in previous_blocks:
|
||||
m = min(pblocks.shape[0], blocks.shape[0])
|
||||
diff = pblocks[:m] - blocks[:m]
|
||||
NZ = diff.nonzero().view(-1)
|
||||
if NZ.shape[0]:
|
||||
# Remove the first offset
|
||||
assert NZ[0] + NZ.shape[0] == m
|
||||
NZ = NZ - NZ[0]
|
||||
assert torch.all(NZ >= 0), f"{pblocks} - blocks {blocks} NZ {NZ}"
|
||||
assert torch.all(NZ == torch.arange(NZ.shape[0], device=NZ.device))
|
||||
|
||||
loffset += input_length
|
||||
coffset += 1
|
||||
previous_slots.append(slot)
|
||||
previous_blocks.append(blocks)
|
||||
# assert cu_seqlen_prefill.shape == position_ids.shape
|
||||
WARMUP = False
|
||||
|
||||
|
||||
def small_power_of_2(n: int):
|
||||
return 1 << ((n - 1).bit_length() - 1)
|
||||
|
||||
@ -135,6 +201,11 @@ def init_cpu_threads_env(rank_id: int, world_size: int):
|
||||
)
|
||||
|
||||
|
||||
from collections import defaultdict
|
||||
|
||||
HISTORY = defaultdict(list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlashCausalLMBatch(Batch):
|
||||
batch_id: int
|
||||
@ -262,6 +333,7 @@ class FlashCausalLMBatch(Batch):
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
) -> "FlashCausalLMBatch":
|
||||
HISTORY[pb.id].append(("TOKENIZED"))
|
||||
speculate = get_speculate()
|
||||
|
||||
cache_lengths = []
|
||||
@ -290,6 +362,7 @@ class FlashCausalLMBatch(Batch):
|
||||
block_tables_ragged = []
|
||||
|
||||
# Parse batch
|
||||
viewed = set()
|
||||
for i, (r, tokenized_input) in enumerate(
|
||||
zip(pb.requests, batch_tokenized_inputs)
|
||||
):
|
||||
@ -304,10 +377,16 @@ class FlashCausalLMBatch(Batch):
|
||||
prompt_lengths.append(prompt_length)
|
||||
|
||||
cache_length = r.cache_len
|
||||
new_slots = r.slots[cache_length:]
|
||||
if set(new_slots).intersection(viewed):
|
||||
import ipdb
|
||||
|
||||
assert (
|
||||
cache_length <= prompt_length
|
||||
), f"Prefix {cache_length} vs input {prompt_length}"
|
||||
ipdb.set_trace()
|
||||
viewed.update(set(new_slots))
|
||||
|
||||
assert cache_length <= prompt_length, (
|
||||
f"Prefix {cache_length} vs input {prompt_length}"
|
||||
)
|
||||
if cache_length == prompt_length:
|
||||
assert False, "unreachable"
|
||||
|
||||
@ -325,9 +404,9 @@ class FlashCausalLMBatch(Batch):
|
||||
postfix_ids = tokenized_input[
|
||||
cache_length : cache_length + input_length
|
||||
]
|
||||
assert (
|
||||
len(postfix_ids) == input_length
|
||||
), "Rust and Python tokenizers are not aligned"
|
||||
assert len(postfix_ids) == input_length, (
|
||||
"Rust and Python tokenizers are not aligned"
|
||||
)
|
||||
else:
|
||||
# Use all the remaining ids
|
||||
postfix_ids = tokenized_input[cache_length:]
|
||||
@ -378,6 +457,7 @@ class FlashCausalLMBatch(Batch):
|
||||
cu_blocks.append(len(block_tables_ragged))
|
||||
|
||||
slots.extend(request_slots)
|
||||
|
||||
cu_slots.append(len(slots))
|
||||
|
||||
cache_lengths.append(cache_length)
|
||||
@ -392,6 +472,25 @@ class FlashCausalLMBatch(Batch):
|
||||
prompt_length + max_new_tokens + speculative_length,
|
||||
)
|
||||
|
||||
# offset = 0
|
||||
# new_slots = []
|
||||
# total_slots = []
|
||||
# for cache_length, input_length in zip(cache_lengths, input_lengths):
|
||||
# new_slots_ = slots[
|
||||
# offset + cache_length : offset + cache_length + input_length
|
||||
# ]
|
||||
# offset += cache_length + input_length
|
||||
# new_slots.extend(new_slots_)
|
||||
# total_slots.append(new_slots_)
|
||||
# if new_slots:
|
||||
# if Counter(new_slots).most_common(1)[0][1] != 1:
|
||||
# import ipdb
|
||||
|
||||
# ipdb.set_trace()
|
||||
# assert Counter(new_slots).most_common(1)[0][1] == 1, (
|
||||
# f"New slots {new_slots}"
|
||||
# )
|
||||
|
||||
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
|
||||
next_token_chooser_parameters, dtype, device, tokenizer
|
||||
)
|
||||
@ -496,6 +595,7 @@ class FlashCausalLMBatch(Batch):
|
||||
|
||||
@tracer.start_as_current_span("filter")
|
||||
def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
|
||||
HISTORY[self.batch_id].append(("FILTER", request_ids))
|
||||
if len(request_ids) == 0:
|
||||
raise ValueError("Batch must have at least one request")
|
||||
# We assume that if len(requests) == len(self) then the requests are the same
|
||||
@ -702,6 +802,7 @@ class FlashCausalLMBatch(Batch):
|
||||
@classmethod
|
||||
@tracer.start_as_current_span("concatenate")
|
||||
def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
|
||||
HISTORY[batches[0].batch_id].append(("CONCATENATE", batches))
|
||||
# Batch attributes
|
||||
requests = []
|
||||
requests_idx_mapping = {}
|
||||
@ -884,6 +985,15 @@ class FlashCausalLMBatch(Batch):
|
||||
cumulative_slots += len(batch.slots)
|
||||
cumulative_batch_size += len(batch)
|
||||
|
||||
if slot_indices:
|
||||
new_slots = slots[slot_indices]
|
||||
import ipdb
|
||||
|
||||
ipdb.set_trace()
|
||||
assert torch.unique(new_slots).shape == new_slots.shape, (
|
||||
f"Slots {new_slots} - Cache {cache_lengths_tensor} Input {input_lengths_tensor} - Slto indices {slot_indices} - Counter {Counter(new_slots.tolist()).most_common(3)} "
|
||||
)
|
||||
|
||||
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
|
||||
next_token_chooser_parameters,
|
||||
dtype=batches[0].next_token_chooser.dtype,
|
||||
@ -991,7 +1101,17 @@ class FlashCausalLMBatch(Batch):
|
||||
cu_slots_gpu,
|
||||
self.position_ids,
|
||||
self.slot_indices,
|
||||
self.slots,
|
||||
)
|
||||
SLOTS = self.slots[self.slot_indices]
|
||||
most_common = Counter(SLOTS.view(-1).tolist()).most_common(3)
|
||||
if torch.unique(SLOTS.view(-1)).shape != SLOTS.view(-1).shape:
|
||||
import ipdb
|
||||
|
||||
ipdb.set_trace()
|
||||
assert torch.unique(SLOTS.view(-1)).shape == SLOTS.view(-1).shape, (
|
||||
f"Slots {self.slots.view(-1)} Indices {self.slot_indices} - COUNTER {most_common} - Diff {self.slots == most_common[0][0]}"
|
||||
)
|
||||
|
||||
sliding_window = get_sliding_windows()
|
||||
position_ids = []
|
||||
@ -1813,7 +1933,7 @@ class FlashCausalLM(Model):
|
||||
kv_cache = self.kv_cache
|
||||
block_tables = batch.block_tables_tensor
|
||||
slots = batch.slots[batch.slot_indices]
|
||||
input_lengths = batch.input_lengths_tensor
|
||||
input_lengths_tensor = batch.input_lengths_tensor
|
||||
cache_lengths_tensor = batch.cache_lengths_tensor
|
||||
max_s = batch.max_current_length
|
||||
lm_head_indices = batch.prefill_head_indices
|
||||
@ -1832,6 +1952,8 @@ class FlashCausalLM(Model):
|
||||
else:
|
||||
cuda_graph = None
|
||||
|
||||
ASSERT_BATCH_IS_CORRECT(batch)
|
||||
|
||||
if cu_seqlen_prefill is not None or cuda_graph is None:
|
||||
if ATTENTION == "flashinfer":
|
||||
block_tables = block_tables_to_ragged(
|
||||
@ -1845,11 +1967,11 @@ class FlashCausalLM(Model):
|
||||
with self._forward_context(
|
||||
block_tables=block_tables,
|
||||
cu_seqlen_prefill=cu_seqlen_prefill,
|
||||
input_lengths_tensor=input_lengths,
|
||||
input_lengths_tensor=input_lengths_tensor,
|
||||
cache_lengths_tensor=cache_lengths_tensor,
|
||||
):
|
||||
seqlen = Seqlen(
|
||||
input_lengths=input_lengths,
|
||||
input_lengths=input_lengths_tensor,
|
||||
cache_lengths=cache_lengths_tensor,
|
||||
cu_seqlen_q=cu_seqlen_prefill,
|
||||
max_q=batch.max_input_length,
|
||||
@ -1897,11 +2019,13 @@ class FlashCausalLM(Model):
|
||||
cuda_graph["slots"].fill_(0)
|
||||
cuda_graph["slots"][: slots.shape[0]] = slots
|
||||
cuda_graph["input_lengths"].zero_()
|
||||
cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
|
||||
cuda_graph["input_lengths"][: input_lengths_tensor.shape[0]] = (
|
||||
input_lengths_tensor
|
||||
)
|
||||
cuda_graph["cache_lengths"].zero_()
|
||||
cuda_graph["cache_lengths"][
|
||||
: cache_lengths_tensor.shape[0]
|
||||
] = cache_lengths_tensor
|
||||
cuda_graph["cache_lengths"][: cache_lengths_tensor.shape[0]] = (
|
||||
cache_lengths_tensor
|
||||
)
|
||||
|
||||
with self._forward_context(
|
||||
block_tables=cuda_graph["block_tables"],
|
||||
|
@ -2,6 +2,7 @@ import torch
|
||||
import triton
|
||||
|
||||
import triton.language as tl
|
||||
from collections import Counter
|
||||
|
||||
from loguru import logger
|
||||
from typing import List, Optional
|
||||
@ -130,6 +131,7 @@ def prepare_position_slot_ids(
|
||||
cu_slots: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
slot_indices: torch.Tensor,
|
||||
slots: torch.Tensor,
|
||||
):
|
||||
def grid(meta):
|
||||
return (
|
||||
@ -140,6 +142,12 @@ def prepare_position_slot_ids(
|
||||
triton_prepare_position_slot_ids[grid](
|
||||
cache_lengths, cu_seqlen, cu_slots, position_ids, slot_indices, BLOCK_SIZE=256
|
||||
)
|
||||
SLOTS = slots[slot_indices]
|
||||
most_common = Counter(SLOTS.view(-1).tolist()).most_common(3)
|
||||
if torch.unique(SLOTS.view(-1)).shape != SLOTS.view(-1).shape:
|
||||
import ipdb
|
||||
|
||||
ipdb.set_trace()
|
||||
|
||||
|
||||
def slots_filtering(
|
||||
@ -158,6 +166,10 @@ def slots_filtering(
|
||||
triton_slots_filtering[grid](
|
||||
slots, filtered_slots, slots_start, cu_slots, BLOCK_SIZE=256
|
||||
)
|
||||
assert torch.all(slots[slots_start] == filtered_slots[cu_slots[:-1]])
|
||||
# assert torch.unique(slots).shape == slots.shape, (
|
||||
# f"Slots {slots} {Counter(slots.tolist()).most_common(3)}"
|
||||
# )
|
||||
|
||||
|
||||
@triton.jit
|
||||
|
@ -11,6 +11,10 @@ from grpc_reflection.v1alpha import reflection
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
from text_generation_server.models.flash_causal_lm import (
|
||||
ASSERT_BATCH_IS_CORRECT,
|
||||
ASSERT_SIMPLE,
|
||||
)
|
||||
from text_generation_server.cache import Cache
|
||||
from text_generation_server.interceptor import ExceptionInterceptor
|
||||
from text_generation_server.models import Model, get_model_with_lora_adapters
|
||||
@ -164,6 +168,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
|
||||
self.model.device,
|
||||
)
|
||||
else:
|
||||
ASSERT_SIMPLE(request.batch)
|
||||
batch = self.model.batch_type.from_pb(
|
||||
request.batch, self.model.tokenizer, self.model.dtype, self.model.device
|
||||
)
|
||||
@ -178,6 +183,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
|
||||
)
|
||||
start_concat = time.time_ns()
|
||||
batch = self.model.batch_type.concatenate([cached_batch, batch])
|
||||
# ASSERT_BATCH_IS_CORRECT(batch)
|
||||
concat_ns = time.time_ns() - start_concat
|
||||
|
||||
generations, next_batch, timings = self.model.generate_token(batch)
|
||||
|
Loading…
Reference in New Issue
Block a user