Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-04-19 22:02:06 +00:00

Commit ddf0b02240 (parent 463228ebfc)
All the assertions. Invariants added. Remove the logs.

Cargo.lock (generated): 22 lines changed
@@ -1169,6 +1169,16 @@ dependencies = [
  "cfg-if",
 ]
 
+[[package]]
+name = "env_logger"
+version = "0.8.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a19187fea3ac7e84da7dacf48de0c45d63c6a76f9490dae389aead16c243fce3"
+dependencies = [
+ "log",
+ "regex",
+]
+
 [[package]]
 name = "equivalent"
 version = "1.0.1"

@@ -3403,6 +3413,17 @@ version = "2.0.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "a993555f31e5a609f617c12db6250dedcac1b0a85076912c436e6fc9b2c8e6a3"
 
+[[package]]
+name = "quickcheck"
+version = "1.0.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "588f6378e4dd99458b60ec275b4477add41ce4fa9f64dcba6f15adccb19b50d6"
+dependencies = [
+ "env_logger",
+ "log",
+ "rand",
+]
+
 [[package]]
 name = "quote"
 version = "1.0.37"

@@ -4630,6 +4651,7 @@ dependencies = [
 "opentelemetry-otlp",
 "prost 0.12.6",
 "prost-build",
+ "quickcheck",
 "rand",
 "regex",
 "reqwest 0.11.27",
@@ -7,6 +7,7 @@ use grpc_metadata::InjectTelemetryContext;
 use pb::generate::v3::text_generation_service_client::TextGenerationServiceClient;
 use pb::generate::v3::*;
 use std::cmp::min;
+use std::io::Error;
 use std::time::Duration;
 use tonic::transport::{Channel, Uri};
 use tracing::instrument;

@@ -232,6 +233,20 @@ impl Client {
         batch: Batch,
         cached_batch: Option<CachedBatch>,
     ) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
+        let slots: Vec<_> = batch
+            .requests
+            .iter()
+            .map(|r| &r.slots[r.cache_len as usize..])
+            .flatten()
+            .collect();
+
+        assert_eq!(
+            slots.len(),
+            slots.iter().collect::<std::collections::HashSet<_>>().len()
+        );
+        if slots.len() != slots.iter().collect::<std::collections::HashSet<_>>().len() {
+            std::process::exit(1);
+        }
         let request = tonic::Request::new(PrefillRequest {
             batch: Some(batch),
             cached_batch,
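The check added in this hunk enforces a single invariant: within one prefill batch, no KV-cache slot may be claimed by two requests (slots below `cache_len` belong to an already cached prefix and are deliberately excluded). Note that the `assert_eq!` will already have panicked before `std::process::exit(1)` can run, so the exit is a redundant debugging stop. A minimal standalone sketch of the invariant, with a hypothetical `Request` struct standing in for the generated protobuf type:

    use std::collections::HashMap;

    // Hypothetical stand-in for the generated protobuf Request type.
    struct Request {
        cache_len: u32,
        slots: Vec<u32>,
    }

    // Count how often each newly claimed slot appears across the batch;
    // any count above 1 violates the invariant the assertion checks.
    fn duplicate_slots(requests: &[Request]) -> HashMap<u32, usize> {
        let mut counts: HashMap<u32, usize> = HashMap::new();
        for r in requests {
            // Slots below cache_len cover the cached prefix and may be
            // shared; only the tail is newly claimed by this request.
            for slot in &r.slots[r.cache_len as usize..] {
                *counts.entry(*slot).or_default() += 1;
            }
        }
        counts.into_iter().filter(|(_, c)| *c > 1).collect()
    }

    fn main() {
        let batch = [
            Request { cache_len: 0, slots: vec![0, 1, 2] },
            // Slot 2 is this request's cached prefix; only slot 3 is new.
            Request { cache_len: 1, slots: vec![2, 3] },
        ];
        assert!(duplicate_slots(&batch).is_empty());
    }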
@@ -63,6 +63,7 @@ base64 = { workspace = true }
 prost = "^0.12"
 tonic = "^0.10"
 tower = "^0.4"
+quickcheck = "1.0.3"
 
 [build-dependencies]
 tonic-build = "0.10.1"
@@ -5,6 +5,7 @@ use crate::client::{
 use crate::queue::{Entry, Queue};
 use async_trait::async_trait;
 use nohash_hasher::IntMap;
+use std::collections::HashMap;
 use std::sync::Arc;
 use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
 use text_generation_router::validation::ValidGenerateRequest;

@@ -13,6 +14,7 @@ use tokio::sync::mpsc::error::SendError;
 use tokio::sync::{mpsc, Notify};
 use tokio::time::Instant;
 use tokio_stream::wrappers::UnboundedReceiverStream;
+use tonic::IntoRequest;
 use tracing::{info_span, instrument, Instrument, Span};
 
 pub struct BackendV3 {
@@ -121,6 +123,35 @@ impl Backend for BackendV3 {
     }
 }
 
+impl Batch {
+    pub fn check(&self) -> Result<(), InferError> {
+        let slots: Vec<_> = self
+            .requests
+            .iter()
+            .map(|r| &r.slots[r.cache_len as usize..])
+            .flatten()
+            .collect();
+
+        // assert_eq!(
+        //     slots.len(),
+        //     slots.iter().collect::<std::collections::HashSet<_>>().len()
+        // );
+        if slots.len() != slots.iter().collect::<std::collections::HashSet<_>>().len() {
+            let mut map: HashMap<u32, usize> = HashMap::new();
+            for slot in slots {
+                *map.entry(*slot).or_default() += 1usize;
+            }
+            let duplicates: HashMap<_, _> = map.into_iter().filter(|(_slot, c)| *c > 1).collect();
+
+            Err(InferError::GenerationError(format!(
+                "Invalid batch: {duplicates:?}",
+            )))
+        } else {
+            Ok(())
+        }
+    }
+}
+
 /// Batching logic
 /// Will be launched in a background Tokio task
 ///
@@ -154,6 +185,7 @@ pub(crate) async fn batching_task(
         )
         .await
     {
+        batch.check().unwrap();
         let mut cached_batch = prefill(&mut client, batch, None, &mut entries)
             .instrument(span)
             .await;

@@ -205,6 +237,7 @@ pub(crate) async fn batching_task(
                 .next_batch(min_size, max_size, prefill_token_budget, token_budget)
                 .await
             {
+                new_batch.check().unwrap();
                 // Tracking metrics
                 if min_size.is_some() {
                     metrics::counter!("tgi_batch_concat", "reason" => "backpressure")

@@ -225,6 +258,7 @@ pub(crate) async fn batching_task(
                     // concatenated during the prefill op server side
                     entries.extend(new_entries);
                     // Generate one token for both the cached batch and the new batch
+                    new_batch.check().unwrap();
                     let new_cached_batch =
                         prefill(&mut client, new_batch, cached_batch, &mut entries)
                             .instrument(span)

@@ -249,6 +283,7 @@ pub(crate) async fn batching_task(
                     });
 
                     // Generate one token for this new batch to have the attention past in cache
+                    new_batch.check().unwrap();
                     let new_cached_batch =
                         prefill(&mut client, new_batch, None, &mut new_entries)
                             .instrument(span)
@@ -19,7 +19,13 @@ pub struct BlockAllocation {
 impl Drop for BlockAllocation {
     fn drop(&mut self) {
         if let Some(block_allocator) = self.block_allocator.as_mut() {
+            tracing::debug!("Freeing block {}", self.allocation_id);
             block_allocator.free(self.blocks.clone(), self.allocation_id)
+        } else {
+            #[cfg(not(test))]
+            {
+                panic!("We didn't have a block allocator");
+            }
         }
     }
 }
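The `#[cfg(not(test))]` block above makes a missing allocator fatal in normal builds while staying silent under `cargo test`, where a `BlockAllocation` may be constructed without a backing allocator. A tiny sketch of the same cfg-gated pattern, using an illustrative `Guard` type:

    // Illustrative type; only the cfg-gated panic pattern matters here.
    struct Guard {
        armed: bool,
    }

    impl Drop for Guard {
        fn drop(&mut self) {
            if !self.armed {
                // Fatal in normal builds, compiled out under test.
                #[cfg(not(test))]
                {
                    panic!("dropped an unarmed guard");
                }
            }
        }
    }

    fn main() {
        drop(Guard { armed: true });
    }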
@@ -1,6 +1,7 @@
 /// Single shard Client
 use crate::client::{pb, Chunk};
 use crate::client::{ClientError, Result, WARMUP_IMAGE_BASE64};
+use axum::http::Error;
 use base64::engine::general_purpose::STANDARD;
 use base64::Engine;
 use grpc_metadata::InjectTelemetryContext;

@@ -232,11 +233,28 @@ impl Client {
         batch: Batch,
         cached_batch: Option<CachedBatch>,
     ) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
+        let slots: Vec<_> = batch
+            .requests
+            .iter()
+            .map(|r| &r.slots[r.cache_len as usize..])
+            .flatten()
+            .collect();
+
+        assert_eq!(
+            slots.len(),
+            slots.iter().collect::<std::collections::HashSet<_>>().len()
+        );
+        if slots.len() != slots.iter().collect::<std::collections::HashSet<_>>().len() {
+            std::process::exit(1);
+        }
         let request = tonic::Request::new(PrefillRequest {
             batch: Some(batch),
             cached_batch,
         })
         .inject_context();
+        // if slots.len() != slots.iter().collect::<std::collections::HashSet<_>>().len() {
+        //     return Err(Error::from("Test"));
+        // }
         let response = self.stub.prefill(request).await?.into_inner();
         Ok((
             response.generations,
@@ -5,7 +5,7 @@ use crate::client::{
 };
 use nohash_hasher::{BuildNoHashHasher, IntMap};
 use std::cmp::max;
-use std::collections::VecDeque;
+use std::collections::{HashMap, VecDeque};
 use text_generation_router::infer::InferError;
 use text_generation_router::infer::InferStreamResponse;
 use text_generation_router::validation::{

@@ -269,6 +269,8 @@ impl State {
         let mut decode_tokens: u32 = 0;
         let mut max_blocks = 0;
 
+        let mut viewed: HashMap<u32, usize> = HashMap::new();
+
         // Pop entries starting from the front of the queue
         'entry_loop: while let Some((id, entry)) = self.entries.pop_front() {
             // Filter entries where the response receiver was dropped (== entries where the request
@@ -311,7 +313,7 @@ impl State {
                 + entry.request.stopping_parameters.max_new_tokens
                 + self.speculate
                 - 1;
-            tracing::debug!("Allocating {tokens} with {input_ids:?}");
+            // tracing::debug!("Allocating {tokens} with {input_ids:?}");
 
             let block_allocation = match block_allocator.allocate(tokens, input_ids).await {
                 None => {

@@ -322,10 +324,11 @@ impl State {
                     break 'entry_loop;
                 }
                 Some(mut block_allocation) => {
-                    tracing::debug!("Allocation: {block_allocation:?}");
+                    // tracing::debug!("Allocation: {block_allocation:?}");
                     max_blocks = max(max_blocks, block_allocation.blocks.len() as u32);
 
-                    if block_allocation.prefix_len == entry.request.input_length {
+                    if block_allocation.prefix_len >= entry.request.input_length {
+                        // panic!("Something wrong happened we have overmatched the prefix {} >= {}", block_allocation.prefix_len, entry.request.input_length);
                         // The whole request was found in the radix trie
                         // However, for the transformer forward to work, we need to
                         // have at least one token of postfix.

@@ -336,6 +339,13 @@ impl State {
                     }
                 };
 
+                let new_slots = &block_allocation.slots[block_allocation.prefix_len as usize..];
+                for s in new_slots {
+                    let entry = viewed.entry(*s).or_default();
+                    *entry += 1;
+                    assert!(*entry <= 1);
+                }
+
                 let postfix_len = entry.request.input_length - block_allocation.prefix_len;
 
                 if prefill_tokens + postfix_len > prefill_token_budget {

@@ -349,6 +359,14 @@ impl State {
                 } else {
                     // We cannot prefill even one token for this entry
                     // Add it back to the queue
+
+                    // Removing the allocations.
+                    tracing::debug!("Removing some allocations");
+                    for s in new_slots {
+                        let entry = viewed.entry(*s).or_default();
+                        *entry -= 1;
+                        assert!(*entry <= 1);
+                    }
                     self.entries.push_front((id, entry));
                 }
                 tracing::debug!(

@@ -363,6 +381,12 @@ impl State {
                     "Over budget: prefill_tokens={} > {prefill_token_budget}",
                     prefill_tokens + postfix_len
                 );
+                tracing::debug!("Removing some allocations");
+                for s in new_slots {
+                    let entry = viewed.entry(*s).or_default();
+                    *entry -= 1;
+                    assert!(*entry <= 1);
+                }
                 self.entries.push_front((id, entry));
                 break 'entry_loop;
             }
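The `viewed` map threaded through the hunks above implements claim/release bookkeeping: each newly assigned slot is counted up when an entry is tentatively scheduled and counted back down when the entry is pushed back on budget overflow, so any count above one flags a double assignment immediately. A standalone sketch of that bookkeeping, independent of the queue types:

    use std::collections::HashMap;

    fn main() {
        let mut viewed: HashMap<u32, usize> = HashMap::new();

        // Tentatively claim slots 7 and 8 for an entry.
        for s in [7u32, 8] {
            let entry = viewed.entry(s).or_default();
            *entry += 1;
            assert!(*entry <= 1, "slot {s} claimed twice");
        }

        // The entry does not fit the token budget: release its slots.
        for s in [7u32, 8] {
            *viewed.entry(s).or_default() -= 1;
        }

        assert!(viewed.values().all(|&c| c == 0));
    }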
@@ -1,5 +1,6 @@
 use crate::block_allocator::{Allocator, BlockAllocation};
 use slotmap::{DefaultKey, SlotMap};
+use std::collections::HashSet;
 use std::hash::{Hash, Hasher};
 use std::{
     collections::{BTreeSet, HashMap},
@@ -86,9 +87,11 @@ impl Allocator for RadixAllocator {
     ) -> Option<BlockAllocation> {
         let mut blocks = vec![];
         let prefix_node = if let Some(prefill_tokens) = prefill_tokens.as_ref() {
-            let node_id = self
-                .cache_blocks
-                .find(prefill_tokens.as_slice(), &mut blocks);
+            let node_id = self.cache_blocks.find(
+                // &prefill_tokens.as_slice()[..prefill_tokens.len().saturating_sub(1)],
+                &prefill_tokens.as_slice(),
+                &mut blocks,
+            );
             node_id
         } else {
             self.cache_blocks.root_id()
@@ -136,6 +139,26 @@ impl Allocator for RadixAllocator {
             slots
         };
 
+        tracing::debug!("Allocated {}", self.allocation_id);
+        let slot_set = slots.iter().collect::<HashSet<_>>();
+        let mut slot_count: HashMap<u32, usize> = HashMap::new();
+        for slot in &slots {
+            let entry = slot_count.entry(*slot).or_default();
+            *entry += 1;
+        }
+        let duplicates: HashMap<u32, usize> =
+            slot_count.into_iter().filter(|(_k, v)| *v > 1).collect();
+        // assert_eq!(slots.len(), slot_set.len(), "Duplicates {duplicates:?}");
+
+        let free_set = self.free_blocks.iter().collect::<HashSet<_>>();
+        assert_eq!(
+            free_set
+                .intersection(&slot_set)
+                .collect::<HashSet<_>>()
+                .len(),
+            0
+        );
+
         let allocation = RadixAllocation {
             prefix_node,
             cached_prefix_len: prefix_len,

@@ -144,6 +167,7 @@ impl Allocator for RadixAllocator {
 
         self.allocation_id += 1;
         self.allocations.insert(self.allocation_id, allocation);
+        tracing::debug!("Allocated {}", self.allocation_id);
 
         Some(BlockAllocation {
             allocation_id: self.allocation_id,
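Besides counting duplicates, the allocation hunk above asserts that freshly allocated slots are disjoint from the allocator's free list. The same set-intersection check in isolation, over illustrative data:

    use std::collections::HashSet;

    fn main() {
        // Illustrative data: the free list and one allocation's slots.
        let free_blocks: Vec<u32> = vec![4, 5, 6];
        let allocated_slots: Vec<u32> = vec![1, 2, 3];

        let free_set: HashSet<_> = free_blocks.iter().collect();
        let slot_set: HashSet<_> = allocated_slots.iter().collect();

        // A slot handed out by the allocator must not sit in the free list.
        assert_eq!(free_set.intersection(&slot_set).count(), 0);
    }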
@@ -155,7 +179,8 @@ impl Allocator for RadixAllocator {
     }
 
     fn free(&mut self, blocks: Vec<u32>, allocation_id: u64) {
-        let allocation = match self.allocations.remove(&allocation_id) {
+        tracing::debug!("Radix free {allocation_id}");
+        let allocation: RadixAllocation = match self.allocations.remove(&allocation_id) {
             Some(allocation) => allocation,
             None => unreachable!("Tried to free an unknown allocation."),
         };
@@ -283,7 +308,7 @@ impl RadixTrie {
     }
 
     /// Find worker.
-    fn find_(&mut self, mut node_id: NodeId, key: &[u32], blocks: &mut Vec<u32>) -> NodeId {
+    fn find_(&mut self, node_id: NodeId, key: &[u32], blocks: &mut Vec<u32>) -> NodeId {
         let node = &self.nodes[node_id];
 
         if key.len() >= self.block_size {

@@ -295,9 +320,13 @@ impl RadixTrie {
             assert_eq!(shared_prefix_len % self.block_size, 0);
             blocks.extend(&child.blocks[..shared_prefix_len / self.block_size]);
 
+            // A node represents the prefix of its children. So, only
+            // recurse when there is a full prefix match.
             let key = &key[shared_prefix_len..];
-            if !key.is_empty() {
-                node_id = self.find_(child_id, key, blocks);
+            if !key.is_empty() && shared_prefix_len == child.key.len() {
+                return self.find_(child_id, key, blocks);
+            } else {
+                return child_id;
             }
         }
     }
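The `find_` change above tightens the trie descent rule, per the new comment: a node stands for the full prefix of its children, so the walk may only recurse into a child whose key matched completely, and on a partial match it must stop at that child rather than descend and report blocks beyond the true shared prefix. A minimal sketch of the rule, with an illustrative `shared_prefix_len` helper rather than the crate's actual API:

    // Length of the common prefix of two keys (illustrative helper).
    fn shared_prefix_len(a: &[u32], b: &[u32]) -> usize {
        a.iter().zip(b).take_while(|(x, y)| x == y).count()
    }

    fn main() {
        let child_key = [1u32, 2, 3, 4];
        let query = [1u32, 2, 9, 9];
        let n = shared_prefix_len(&child_key, &query);
        assert_eq!(n, 2);
        // Partial match (n < child_key.len()): stop at this child instead
        // of recursing, since the child's descendants all extend its full
        // key and cannot share more of the query than the child does.
        assert!(n < child_key.len());
    }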
@@ -369,7 +398,6 @@ impl RadixTrie {
 
         while let Some((last_access, node_id)) = self.leaves.pop_first() {
             let blocks_needed = n_blocks.saturating_sub(evicted.len());
-            tracing::debug!("Evicting node {node_id:?} ");
 
             let node = self.nodes.get(node_id).expect("Leave does not exist");
             assert_eq!(

@@ -381,11 +409,8 @@ impl RadixTrie {
             if blocks_needed >= node.blocks.len() {
                 // We need to evict the whole node if we need more blocks than it has.
                 let node = self.remove_node(node_id);
+                tracing::debug!("Evicted node {node_id:?} got back {}", node.blocks.len());
                 evicted.extend(node.blocks);
-
-                if evicted.len() >= n_blocks {
-                    break;
-                }
             } else {
                 // The node has more blocks than needed, so we'll just remove
                 // the required number of blocks and leave the remaining blocks

@@ -397,6 +422,10 @@ impl RadixTrie {
                 node.key.truncate(truncate_tokens);
                 evicted.extend(node.blocks.split_off(truncate_blocks));
                 self.leaves.insert((last_access, node_id));
+                tracing::debug!("Evicted partial node {node_id:?} got {blocks_needed} back",);
+            }
+            if evicted.len() >= n_blocks {
+                tracing::debug!("Got enough {}", evicted.len());
                 break;
             }
         }
@@ -873,4 +902,76 @@ mod tests {
         // Clear out the whole trie.
         assert_eq!(trie.evict(10), vec![1, 2, 3, 0, 1]);
     }
+
+    #[derive(Clone, Debug)]
+    enum Command {
+        Allocate {
+            tokens: u32,
+            prefill: Option<Arc<Vec<u32>>>,
+        },
+        Free {
+            blocks: Vec<u32>,
+            allocation_id: u64,
+        },
+    }
+
+    #[derive(Clone)]
+    struct Vocab(u32);
+
+    impl Arbitrary for Vocab {
+        fn arbitrary(gen: &mut Gen) -> Self {
+            let free = bool::arbitrary(gen);
+            if free {
+                Vocab(0)
+            } else {
+                Vocab(1)
+            }
+        }
+    }
+
+    use quickcheck::quickcheck;
+    use quickcheck::{Arbitrary, Gen};
+
+    impl Arbitrary for Command {
+        fn arbitrary(gen: &mut Gen) -> Self {
+            let free = bool::arbitrary(gen);
+            if free {
+                let blocks: Vec<u32> = Vec::arbitrary(gen);
+                let allocation_id = u64::arbitrary(gen);
+                Command::Free {
+                    blocks,
+                    allocation_id,
+                }
+            } else {
+                let tokens = u32::arbitrary(gen);
+                let prefill_tokens: Vec<Vocab> = Vec::arbitrary(gen);
+                let prefill_tokens = prefill_tokens.into_iter().map(|v| v.0).collect();
+                let prefill = Some(Arc::new(prefill_tokens));
+                Command::Allocate { tokens, prefill }
+            }
+        }
+    }
+
+    quickcheck! {
+        fn allocator_commands(commands: Vec<Command>) -> bool {
+            let mut cache = RadixAllocator::new(1, 20, None);
+            let mut allocations = vec![];
+            for command in commands {
+                match command {
+                    Command::Allocate { tokens, prefill } => {
+                        let allocation = cache.allocate(tokens, prefill);
+                        if let Some(allocation) = allocation {
+                            allocations.push(allocation.allocation_id);
+                        }
+                    }
+                    Command::Free { blocks, allocation_id } => {
+                        if allocations.contains(&allocation_id) {
+                            cache.free(blocks, allocation_id);
+                        }
+                    }
+                }
+            }
+            true
+        }
+    }
 }
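The quickcheck property above returns `true` unconditionally; its value is that randomly generated `Allocate`/`Free` sequences drive the allocator through the newly added internal assertions, which panic (and therefore fail the test) on any invariant violation. A toy sketch of this command-driven pattern, with an illustrative clamped counter in place of `RadixAllocator`:

    use quickcheck::{quickcheck, Arbitrary, Gen};

    // Illustrative operation type, mirroring the Allocate/Free commands.
    #[derive(Clone, Debug)]
    enum Op {
        Incr,
        Decr,
    }

    impl Arbitrary for Op {
        fn arbitrary(g: &mut Gen) -> Self {
            if bool::arbitrary(g) {
                Op::Incr
            } else {
                Op::Decr
            }
        }
    }

    quickcheck! {
        // The body returns true; the real check is the internal
        // assertion, just as allocator_commands relies on the
        // allocator's own invariant assertions to panic on violation.
        fn counter_never_negative(ops: Vec<Op>) -> bool {
            let mut n: i64 = 0;
            for op in ops {
                match op {
                    Op::Incr => n += 1,
                    // Clamp at zero so the invariant below always holds.
                    Op::Decr => n = (n - 1).max(0),
                }
            }
            assert!(n >= 0);
            true
        }
    }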
@@ -505,7 +505,7 @@ async fn generate_stream_internal(
     let start_time = Instant::now();
     metrics::counter!("tgi_request_count").increment(1);
 
-    tracing::debug!("Input: {}", req.inputs);
+    // tracing::debug!("Input: {}", req.inputs);
 
     let compute_characters = req.inputs.chars().count();
 
@@ -4,6 +4,7 @@ import os
 import time
 import torch
 import torch.distributed
+from collections import Counter
 
 import numpy as np
 
@@ -87,6 +88,71 @@ tracer = trace.get_tracer(__name__)
 SLIDING_WINDOW: Optional[int] = None
 
 
+WARMUP = True
+
+
+def ASSERT_SIMPLE(batch):
+    slots = []
+    for r in batch.requests:
+        slots.extend(r.slots[r.cache_len :])
+    assert len(set(slots)) == len(slots)
+
+
+def ASSERT_BATCH_IS_CORRECT(batch):
+    global WARMUP
+    input_ids = batch.input_ids
+    position_ids = batch.position_ids
+    cu_seqlen_prefill = batch.cu_seqlen_prefill
+    # kv_cache = self.kv_cache
+    block_tables = batch.block_tables_tensor
+    slots = batch.slots[batch.slot_indices]
+    input_lengths_tensor = batch.input_lengths_tensor
+    cache_lengths_tensor = batch.cache_lengths_tensor
+    max_s = batch.max_current_length
+    lm_head_indices = batch.prefill_head_indices
+    assert input_ids.shape == position_ids.shape
+    # print(input_lengths_tensor, cache_lengths_tensor, slots, block_tables)
+    assert input_lengths_tensor.shape == cache_lengths_tensor.shape
+    assert torch.all(cache_lengths_tensor >= 0)
+    assert torch.all(input_lengths_tensor > 0)
+
+    loffset = 0
+    coffset = 0
+    assert torch.unique(slots).shape == slots.shape, (
+        f"Slots {slots} - Cache {cache_lengths_tensor} Input {input_lengths_tensor} - Slto indices {batch.slot_indices} - Counter {Counter(slots.tolist()).most_common(3)} "
+    )
+
+    previous_slots = []
+    previous_blocks = []
+    for input_length, cache_length in zip(input_lengths_tensor, cache_lengths_tensor):
+        slot = slots[loffset : loffset + input_length]
+        blocks = block_tables[coffset][: input_length + cache_length]
+        assert len(slot.shape) == 1
+        # print(f"Blocks {blocks} - Slots {slots}")
+        assert torch.all(blocks[cache_length : cache_length + input_length] == slot)
+        if not WARMUP:
+            assert torch.all(blocks != 0)
+            assert torch.unique(blocks).shape == blocks.shape
+
+            for pblocks in previous_blocks:
+                m = min(pblocks.shape[0], blocks.shape[0])
+                diff = pblocks[:m] - blocks[:m]
+                NZ = diff.nonzero().view(-1)
+                if NZ.shape[0]:
+                    # Remove the first offset
+                    assert NZ[0] + NZ.shape[0] == m
+                    NZ = NZ - NZ[0]
+                    assert torch.all(NZ >= 0), f"{pblocks} - blocks {blocks} NZ {NZ}"
+                    assert torch.all(NZ == torch.arange(NZ.shape[0], device=NZ.device))
+
+        loffset += input_length
+        coffset += 1
+        previous_slots.append(slot)
+        previous_blocks.append(blocks)
+    # assert cu_seqlen_prefill.shape == position_ids.shape
+    WARMUP = False
+
+
 def small_power_of_2(n: int):
     return 1 << ((n - 1).bit_length() - 1)
 
@@ -135,6 +201,11 @@ def init_cpu_threads_env(rank_id: int, world_size: int):
     )
 
 
+from collections import defaultdict
+
+HISTORY = defaultdict(list)
+
+
 @dataclass
 class FlashCausalLMBatch(Batch):
     batch_id: int
@@ -262,6 +333,7 @@ class FlashCausalLMBatch(Batch):
         dtype: torch.dtype,
         device: torch.device,
     ) -> "FlashCausalLMBatch":
+        HISTORY[pb.id].append(("TOKENIZED"))
         speculate = get_speculate()
 
         cache_lengths = []

@@ -290,6 +362,7 @@ class FlashCausalLMBatch(Batch):
         block_tables_ragged = []
 
         # Parse batch
+        viewed = set()
         for i, (r, tokenized_input) in enumerate(
             zip(pb.requests, batch_tokenized_inputs)
         ):
@@ -304,10 +377,16 @@ class FlashCausalLMBatch(Batch):
             prompt_lengths.append(prompt_length)
 
             cache_length = r.cache_len
+            new_slots = r.slots[cache_length:]
+            if set(new_slots).intersection(viewed):
+                import ipdb
+
+                ipdb.set_trace()
+            viewed.update(set(new_slots))
+
-            assert (
-                cache_length <= prompt_length
-            ), f"Prefix {cache_length} vs input {prompt_length}"
+            assert cache_length <= prompt_length, (
+                f"Prefix {cache_length} vs input {prompt_length}"
+            )
             if cache_length == prompt_length:
                 assert False, "unreachable"
 

@@ -325,9 +404,9 @@ class FlashCausalLMBatch(Batch):
                 postfix_ids = tokenized_input[
                     cache_length : cache_length + input_length
                 ]
-                assert (
-                    len(postfix_ids) == input_length
-                ), "Rust and Python tokenizers are not aligned"
+                assert len(postfix_ids) == input_length, (
+                    "Rust and Python tokenizers are not aligned"
+                )
             else:
                 # Use all the remaining ids
                 postfix_ids = tokenized_input[cache_length:]
@@ -378,6 +457,7 @@ class FlashCausalLMBatch(Batch):
             cu_blocks.append(len(block_tables_ragged))
 
             slots.extend(request_slots)
+
             cu_slots.append(len(slots))
 
             cache_lengths.append(cache_length)

@@ -392,6 +472,25 @@ class FlashCausalLMBatch(Batch):
                 prompt_length + max_new_tokens + speculative_length,
             )
 
+        # offset = 0
+        # new_slots = []
+        # total_slots = []
+        # for cache_length, input_length in zip(cache_lengths, input_lengths):
+        #     new_slots_ = slots[
+        #         offset + cache_length : offset + cache_length + input_length
+        #     ]
+        #     offset += cache_length + input_length
+        #     new_slots.extend(new_slots_)
+        #     total_slots.append(new_slots_)
+        # if new_slots:
+        #     if Counter(new_slots).most_common(1)[0][1] != 1:
+        #         import ipdb
+
+        #         ipdb.set_trace()
+        #     assert Counter(new_slots).most_common(1)[0][1] == 1, (
+        #         f"New slots {new_slots}"
+        #     )
+
         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
             next_token_chooser_parameters, dtype, device, tokenizer
         )
@@ -496,6 +595,7 @@ class FlashCausalLMBatch(Batch):
 
     @tracer.start_as_current_span("filter")
     def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
+        HISTORY[self.batch_id].append(("FILTER", request_ids))
         if len(request_ids) == 0:
             raise ValueError("Batch must have at least one request")
         # We assume that if len(requests) == len(self) then the requests are the same

@@ -702,6 +802,7 @@ class FlashCausalLMBatch(Batch):
     @classmethod
     @tracer.start_as_current_span("concatenate")
     def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
+        HISTORY[batches[0].batch_id].append(("CONCATENATE", batches))
         # Batch attributes
         requests = []
         requests_idx_mapping = {}
@@ -884,6 +985,15 @@ class FlashCausalLMBatch(Batch):
             cumulative_slots += len(batch.slots)
             cumulative_batch_size += len(batch)
 
+        if slot_indices:
+            new_slots = slots[slot_indices]
+            import ipdb
+
+            ipdb.set_trace()
+            assert torch.unique(new_slots).shape == new_slots.shape, (
+                f"Slots {new_slots} - Cache {cache_lengths_tensor} Input {input_lengths_tensor} - Slto indices {slot_indices} - Counter {Counter(new_slots.tolist()).most_common(3)} "
+            )
+
         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
             next_token_chooser_parameters,
             dtype=batches[0].next_token_chooser.dtype,
@@ -991,7 +1101,17 @@ class FlashCausalLMBatch(Batch):
             cu_slots_gpu,
             self.position_ids,
             self.slot_indices,
+            self.slots,
         )
+        SLOTS = self.slots[self.slot_indices]
+        most_common = Counter(SLOTS.view(-1).tolist()).most_common(3)
+        if torch.unique(SLOTS.view(-1)).shape != SLOTS.view(-1).shape:
+            import ipdb
+
+            ipdb.set_trace()
+        assert torch.unique(SLOTS.view(-1)).shape == SLOTS.view(-1).shape, (
+            f"Slots {self.slots.view(-1)} Indices {self.slot_indices} - COUNTER {most_common} - Diff {self.slots == most_common[0][0]}"
+        )
 
         sliding_window = get_sliding_windows()
         position_ids = []
@@ -1813,7 +1933,7 @@ class FlashCausalLM(Model):
         kv_cache = self.kv_cache
         block_tables = batch.block_tables_tensor
         slots = batch.slots[batch.slot_indices]
-        input_lengths = batch.input_lengths_tensor
+        input_lengths_tensor = batch.input_lengths_tensor
         cache_lengths_tensor = batch.cache_lengths_tensor
         max_s = batch.max_current_length
         lm_head_indices = batch.prefill_head_indices

@@ -1832,6 +1952,8 @@ class FlashCausalLM(Model):
         else:
             cuda_graph = None
 
+        ASSERT_BATCH_IS_CORRECT(batch)
+
         if cu_seqlen_prefill is not None or cuda_graph is None:
             if ATTENTION == "flashinfer":
                 block_tables = block_tables_to_ragged(

@@ -1845,11 +1967,11 @@ class FlashCausalLM(Model):
             with self._forward_context(
                 block_tables=block_tables,
                 cu_seqlen_prefill=cu_seqlen_prefill,
-                input_lengths_tensor=input_lengths,
+                input_lengths_tensor=input_lengths_tensor,
                 cache_lengths_tensor=cache_lengths_tensor,
             ):
                 seqlen = Seqlen(
-                    input_lengths=input_lengths,
+                    input_lengths=input_lengths_tensor,
                     cache_lengths=cache_lengths_tensor,
                     cu_seqlen_q=cu_seqlen_prefill,
                     max_q=batch.max_input_length,

@@ -1897,11 +2019,13 @@ class FlashCausalLM(Model):
         cuda_graph["slots"].fill_(0)
         cuda_graph["slots"][: slots.shape[0]] = slots
         cuda_graph["input_lengths"].zero_()
-        cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
+        cuda_graph["input_lengths"][: input_lengths_tensor.shape[0]] = (
+            input_lengths_tensor
+        )
         cuda_graph["cache_lengths"].zero_()
-        cuda_graph["cache_lengths"][
-            : cache_lengths_tensor.shape[0]
-        ] = cache_lengths_tensor
+        cuda_graph["cache_lengths"][: cache_lengths_tensor.shape[0]] = (
+            cache_lengths_tensor
+        )
 
         with self._forward_context(
             block_tables=cuda_graph["block_tables"],
@@ -2,6 +2,7 @@ import torch
 import triton
 
 import triton.language as tl
+from collections import Counter
 
 from loguru import logger
 from typing import List, Optional
@@ -130,6 +131,7 @@ def prepare_position_slot_ids(
     cu_slots: torch.Tensor,
     position_ids: torch.Tensor,
     slot_indices: torch.Tensor,
+    slots: torch.Tensor,
 ):
     def grid(meta):
         return (

@@ -140,6 +142,12 @@ def prepare_position_slot_ids(
     triton_prepare_position_slot_ids[grid](
         cache_lengths, cu_seqlen, cu_slots, position_ids, slot_indices, BLOCK_SIZE=256
    )
+    SLOTS = slots[slot_indices]
+    most_common = Counter(SLOTS.view(-1).tolist()).most_common(3)
+    if torch.unique(SLOTS.view(-1)).shape != SLOTS.view(-1).shape:
+        import ipdb
+
+        ipdb.set_trace()
 
 
 def slots_filtering(
@@ -158,6 +166,10 @@ def slots_filtering(
     triton_slots_filtering[grid](
         slots, filtered_slots, slots_start, cu_slots, BLOCK_SIZE=256
     )
+    assert torch.all(slots[slots_start] == filtered_slots[cu_slots[:-1]])
+    # assert torch.unique(slots).shape == slots.shape, (
+    #     f"Slots {slots} {Counter(slots.tolist()).most_common(3)}"
+    # )
 
 
 @triton.jit
@@ -11,6 +11,10 @@ from grpc_reflection.v1alpha import reflection
 from pathlib import Path
 from typing import List, Optional
 
+from text_generation_server.models.flash_causal_lm import (
+    ASSERT_BATCH_IS_CORRECT,
+    ASSERT_SIMPLE,
+)
 from text_generation_server.cache import Cache
 from text_generation_server.interceptor import ExceptionInterceptor
 from text_generation_server.models import Model, get_model_with_lora_adapters

@@ -164,6 +168,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
                 self.model.device,
             )
         else:
+            ASSERT_SIMPLE(request.batch)
             batch = self.model.batch_type.from_pb(
                 request.batch, self.model.tokenizer, self.model.dtype, self.model.device
             )

@@ -178,6 +183,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
             )
             start_concat = time.time_ns()
             batch = self.model.batch_type.concatenate([cached_batch, batch])
+            # ASSERT_BATCH_IS_CORRECT(batch)
             concat_ns = time.time_ns() - start_concat
 
             generations, next_batch, timings = self.model.generate_token(batch)