Added padded blocks and logging everywhere

This commit is contained in:
OlivierDehaene 2024-06-18 12:18:05 +02:00
parent abe521204e
commit 7ed1044585
8 changed files with 73 additions and 25 deletions

View File

@ -497,7 +497,7 @@ fn statis_spans<'a>(data: &[f64], unit: &'static str) -> Vec<Line<'a>> {
"Lowest: {:.2} {unit}", "Lowest: {:.2} {unit}",
data.iter() data.iter()
.min_by(|a, b| a.total_cmp(b)) .min_by(|a, b| a.total_cmp(b))
.unwrap_or(&std::f64::NAN) .unwrap_or(&f64::NAN)
), ),
Style::default().fg(Color::Reset), Style::default().fg(Color::Reset),
)]), )]),
@ -506,7 +506,7 @@ fn statis_spans<'a>(data: &[f64], unit: &'static str) -> Vec<Line<'a>> {
"Highest: {:.2} {unit}", "Highest: {:.2} {unit}",
data.iter() data.iter()
.max_by(|a, b| a.total_cmp(b)) .max_by(|a, b| a.total_cmp(b))
.unwrap_or(&std::f64::NAN) .unwrap_or(&f64::NAN)
), ),
Style::default().fg(Color::Reset), Style::default().fg(Color::Reset),
)]), )]),
@ -555,17 +555,17 @@ fn latency_throughput_chart<'a>(
let min_latency: f64 = *latency_iter let min_latency: f64 = *latency_iter
.clone() .clone()
.min_by(|a, b| a.total_cmp(b)) .min_by(|a, b| a.total_cmp(b))
.unwrap_or(&std::f64::NAN); .unwrap_or(&f64::NAN);
let max_latency: f64 = *latency_iter let max_latency: f64 = *latency_iter
.max_by(|a, b| a.total_cmp(b)) .max_by(|a, b| a.total_cmp(b))
.unwrap_or(&std::f64::NAN); .unwrap_or(&f64::NAN);
let min_throughput: f64 = *throughput_iter let min_throughput: f64 = *throughput_iter
.clone() .clone()
.min_by(|a, b| a.total_cmp(b)) .min_by(|a, b| a.total_cmp(b))
.unwrap_or(&std::f64::NAN); .unwrap_or(&f64::NAN);
let max_throughput: f64 = *throughput_iter let max_throughput: f64 = *throughput_iter
.max_by(|a, b| a.total_cmp(b)) .max_by(|a, b| a.total_cmp(b))
.unwrap_or(&std::f64::NAN); .unwrap_or(&f64::NAN);
// Char min max values // Char min max values
let min_x = if zoom { let min_x = if zoom {

View File

@ -156,17 +156,17 @@ fn avg_min_max(data: &[f64]) -> (f64, f64, f64) {
let min = data let min = data
.iter() .iter()
.min_by(|a, b| a.total_cmp(b)) .min_by(|a, b| a.total_cmp(b))
.unwrap_or(&std::f64::NAN); .unwrap_or(&f64::NAN);
let max = data let max = data
.iter() .iter()
.max_by(|a, b| a.total_cmp(b)) .max_by(|a, b| a.total_cmp(b))
.unwrap_or(&std::f64::NAN); .unwrap_or(&f64::NAN);
(average, *min, *max) (average, *min, *max)
} }
fn px(data: &[f64], p: u32) -> f64 { fn px(data: &[f64], p: u32) -> f64 {
let i = (f64::from(p) / 100.0 * data.len() as f64) as usize; let i = (f64::from(p) / 100.0 * data.len() as f64) as usize;
*data.get(i).unwrap_or(&std::f64::NAN) *data.get(i).unwrap_or(&f64::NAN)
} }
fn format_value(value: f64, unit: &'static str) -> String { fn format_value(value: f64, unit: &'static str) -> String {

View File

@ -37,7 +37,7 @@ pub(crate) fn percentiles(values: &[f64], pecents: &[i32]) -> BTreeMap<String, f
.iter() .iter()
.map(|&p| { .map(|&p| {
let i = (f64::from(p) / 100.0 * values.len() as f64) as usize; let i = (f64::from(p) / 100.0 * values.len() as f64) as usize;
(format!("p{p}"), *values.get(i).unwrap_or(&std::f64::NAN)) (format!("p{p}"), *values.get(i).unwrap_or(&f64::NAN))
}) })
.collect() .collect()
} }

View File

@ -206,6 +206,8 @@ message KeptRequest {
uint64 id = 1; uint64 id = 1;
/// Paged attention blocks /// Paged attention blocks
repeated uint32 blocks = 2; repeated uint32 blocks = 2;
/// Paged attention blocks padded to max blocks for this batch
repeated uint32 padded_blocks = 3;
} }
/// kept_requests + terminated_request_ids might not cover all requests from the /// kept_requests + terminated_request_ids might not cover all requests from the

View File

@ -32,7 +32,7 @@ impl BlockAllocation {
self.required_blocks, self.required_blocks,
), ),
}; };
let remaining_blocks = required_blocks - self.allocated_blocks.len(); let remaining_blocks = required_blocks.saturating_sub(self.allocated_blocks.len());
let new_blocks = min(remaining_blocks, 16); let new_blocks = min(remaining_blocks, 16);
// Try to allocate all remaining blocks // Try to allocate all remaining blocks

View File

@ -314,6 +314,9 @@ async fn decode(
// Filter and send finished generations // Filter and send finished generations
let mut filtered_stream_responses = filter_send_ended_generations(generations, entries); let mut filtered_stream_responses = filter_send_ended_generations(generations, entries);
tracing::info!("filtered_stream: {:?}", start_filtering_time.elapsed());
// Send `StreamResponseInfer::Intermediate` messages for entries that don't need to be // Send `StreamResponseInfer::Intermediate` messages for entries that don't need to be
// re-allocated, // re-allocated,
// Allocated new blocks for entries that go over their allocation // Allocated new blocks for entries that go over their allocation
@ -321,17 +324,21 @@ async fn decode(
let (force_update, terminated_entries) = let (force_update, terminated_entries) =
filter_send_update_allocations(entries, &mut filtered_stream_responses); filter_send_update_allocations(entries, &mut filtered_stream_responses);
tracing::info!("filtered_update: {:?}", start_filtering_time.elapsed());
let next_batch = match next_batch { let next_batch = match next_batch {
// Run Only on re-allocation or if entries were filtered // Run Only on re-allocation or if entries were filtered
Some(batch) if batch.size as usize != entries.len() || force_update => { Some(batch) if batch.size as usize != entries.len() || force_update => {
// Filter next batch: remove requests that were stopped and update blocks/slots // Filter next batch: remove requests that were stopped and update blocks/slots
let (filtered_batch, terminated_generations) = let (filtered_batch, terminated_generations) =
filter_batch(client, batch, entries, &terminated_entries).await; filter_batch(client, batch, entries, &terminated_entries).await;
tracing::info!("filter_batch: {:?}", start_filtering_time.elapsed());
send_terminated_generations( send_terminated_generations(
terminated_generations, terminated_generations,
terminated_entries, terminated_entries,
filtered_stream_responses, filtered_stream_responses,
); );
tracing::info!("send_terminated: {:?}", start_filtering_time.elapsed());
filtered_batch filtered_batch
} }
@ -379,23 +386,49 @@ async fn filter_batch(
client.clear_cache(Some(id)).await.unwrap(); client.clear_cache(Some(id)).await.unwrap();
Default::default() Default::default()
} else { } else {
// Collect new blocks/slots let max_blocks = entries
.iter()
.map(|(_, entry)| {
entry
.block_allocation
.as_ref()
.map(|alloc| alloc.blocks().len())
})
.max()
.flatten();
let start_time = Instant::now();
// Collect new blocks
let updated_requests = entries let updated_requests = entries
.iter() .iter()
.map(|(request_id, entry)| { .map(|(request_id, entry)| {
let blocks = entry let (blocks, padded_blocks) = entry
.block_allocation .block_allocation
.as_ref() .as_ref()
.map(|alloc| alloc.blocks().to_vec()) .map(|alloc| {
let max_blocks = match max_blocks {
Some(max_blocks) => max_blocks,
_ => unreachable!(),
};
let blocks = alloc.blocks().to_vec();
let mut padded_blocks = blocks.clone();
padded_blocks.resize(max_blocks - padded_blocks.len(), 0);
(blocks, padded_blocks)
})
.unwrap_or_default(); .unwrap_or_default();
KeptRequest { KeptRequest {
id: *request_id, id: *request_id,
blocks, blocks,
padded_blocks,
} }
}) })
.collect(); .collect();
tracing::info!("Collect blocks: {:?}", start_time.elapsed());
// Filter Python shard cache // Filter Python shard cache
// We unwrap here as we need to panic since we cannot recover if this method fails // We unwrap here as we need to panic since we cannot recover if this method fails
client client

View File

@ -1,5 +1,5 @@
[toolchain] [toolchain]
# Released on: 02 May, 2024 # Released on: 13 June, 2024
# https://releases.rs/docs/1.78.0/ # https://releases.rs/docs/1.79.0/
channel = "1.78.0" channel = "1.79.0"
components = ["rustfmt", "clippy"] components = ["rustfmt", "clippy"]

View File

@ -403,6 +403,8 @@ class FlashCausalLMBatch(Batch):
kept_requests: List[generate_pb2.KeptRequest], kept_requests: List[generate_pb2.KeptRequest],
terminated_request_ids: List[int], terminated_request_ids: List[int],
) -> Tuple[Optional["FlashCausalLMBatch"], List[generate_pb2.TerminatedGeneration]]: ) -> Tuple[Optional["FlashCausalLMBatch"], List[generate_pb2.TerminatedGeneration]]:
start = time.time_ns()
terminated_generations = [] terminated_generations = []
for request_id in terminated_request_ids: for request_id in terminated_request_ids:
idx = self.requests_idx_mapping[request_id] idx = self.requests_idx_mapping[request_id]
@ -429,6 +431,11 @@ class FlashCausalLMBatch(Batch):
), ),
) )
) )
from loguru import logger
logger.info(f"terminated generations {(time.time_ns() - start)/1e6}")
if not kept_requests: if not kept_requests:
return None, terminated_generations return None, terminated_generations
@ -445,7 +452,7 @@ class FlashCausalLMBatch(Batch):
requests = [] requests = []
flat_blocks = [] flat_blocks = []
block_tables = [] padded_blocks = []
all_input_ids = [] all_input_ids = []
input_lengths = [] input_lengths = []
@ -483,8 +490,8 @@ class FlashCausalLMBatch(Batch):
top_n_tokens.append(self.top_n_tokens[idx]) top_n_tokens.append(self.top_n_tokens[idx])
request_block_table = request.blocks request_block_table = request.blocks
block_tables.append(request_block_table)
flat_blocks.extend(request_block_table) flat_blocks.extend(request_block_table)
padded_blocks.extend(request.padded_blocks)
# Index # Index
slot_indices.append((num_blocks * BLOCK_SIZE) + request_input_length - 1) slot_indices.append((num_blocks * BLOCK_SIZE) + request_input_length - 1)
@ -492,6 +499,8 @@ class FlashCausalLMBatch(Batch):
num_blocks += len(request_block_table) num_blocks += len(request_block_table)
max_blocks = max(max_blocks, len(request_block_table)) max_blocks = max(max_blocks, len(request_block_table))
logger.info(f"for loop requests: {(time.time_ns() - start)/1e6}")
# Index into tensors # Index into tensors
input_ids = self.input_ids[indices] input_ids = self.input_ids[indices]
position_ids = self.position_ids[indices] position_ids = self.position_ids[indices]
@ -503,12 +512,14 @@ class FlashCausalLMBatch(Batch):
self.speculative_ids[indices] if self.speculative_ids is not None else None self.speculative_ids[indices] if self.speculative_ids is not None else None
) )
# Create block_tables_tensor on CPU logger.info(f"slice objects: {(time.time_ns() - start)/1e6}")
block_tables_tensor = torch.zeros(
(len(block_tables), max_blocks), dtype=torch.int32, device="cpu" # Create block_tables_tensor on GPU
) block_tables_tensor = torch.tensor(
for i, request_blocks in enumerate(block_tables): padded_blocks, dtype=torch.int32, device=device
block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks) ).view(len(requests), -1)
logger.info(f"allocate block table: {(time.time_ns() - start)/1e6}")
# Allocate on GPU # Allocate on GPU
slot_indices = torch.tensor(slot_indices, dtype=torch.int64, device=device) slot_indices = torch.tensor(slot_indices, dtype=torch.int64, device=device)
@ -522,6 +533,8 @@ class FlashCausalLMBatch(Batch):
+ torch.arange(0, BLOCK_SIZE, device=device, dtype=torch.int64) + torch.arange(0, BLOCK_SIZE, device=device, dtype=torch.int64)
).flatten() ).flatten()
logger.info(f"done allocation: {(time.time_ns() - start)/1e6}")
filtered_batch = type(self)( filtered_batch = type(self)(
batch_id=self.batch_id, batch_id=self.batch_id,
requests=requests, requests=requests,