Make prefill time of static benchmark correct (#214)

Sun Choi 2024-08-26 01:51:28 -07:00, committed by GitHub
parent a8cead1f92
commit ea48ae169a
3 changed files with 6 additions and 3 deletions

@@ -163,8 +163,11 @@ async fn prefill(
     // Run prefill
     let start_time = Instant::now();
     let (_, decode_batch, _) = client.prefill(batch.clone()).await?;
+    let (_, decode_batch, _) = client.decode(vec![decode_batch.clone().unwrap()]).await?;
     // Get latency
     let latency = start_time.elapsed();
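With speculative scheduling, the first decode step is executed as part of prefill (see the decode_length comment in the next hunk), so timing client.prefill() alone undercounts the time to the first token. A minimal sketch of the corrected measurement, assuming the ShardedClient API visible in this diff:

    // Time prefill plus the first decode step: the prefill latency reported
    // by the benchmark is then the time to the first generated token.
    let start_time = Instant::now();
    let (_, decode_batch, _) = client.prefill(batch.clone()).await?;
    // The first decode step is scheduled together with prefill, so it
    // belongs inside this measurement window rather than in the decode loop.
    let (_, decode_batch, _) = client.decode(vec![decode_batch.clone().unwrap()]).await?;
    let latency = start_time.elapsed();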
@@ -180,11 +183,12 @@ async fn prefill(
     };
     Ok((step, decode_batch))
 }

 /// Run a full decode
 async fn decode(batch: CachedBatch, client: &mut ShardedClient) -> Result<Decode, ClientError> {
-    let mut decode_length = 0;
+    let mut decode_length = 1; // 1 decode step was already scheduled in prefill with speculative scheduling
     let batch_size = batch.size;

     let start_time = Instant::now();
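The counter change is the other half of the fix: one decode step has already run, and been timed, inside prefill, so the decode loop starts counting at 1 instead of 0. A rough sketch of how a decode loop lines up with that initial value; the loop body here is illustrative, not the file's exact code:

    // Start at 1: one decode step was already scheduled during prefill.
    let mut decode_length = 1;
    let mut next_batch = Some(batch);
    while let Some(batch) = next_batch {
        let (_, cached_batch, _) = client.decode(vec![batch]).await?;
        next_batch = cached_batch;
        decode_length += 1;
    }
    // decode_length now counts every generated token, including the first
    // one produced during prefill.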

@@ -51,7 +51,7 @@ struct Args {
     runs: usize,

     /// Number of warmup cycles
-    #[clap(default_value = "1", short, long, env)]
+    #[clap(default_value = "3", short, long, env)]
     warmups: usize,

     /// The location of the grpc socket. This benchmark tool bypasses the router
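Raising the default warmup count from 1 to 3 gives the benchmark extra untimed iterations to reach a steady state before measurement; on accelerators whose first iterations trigger one-off work such as graph or kernel compilation, a single warmup cycle is often not enough. Because of the clap attribute, the default can still be overridden with the -w/--warmups flag or the matching environment variable.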

@@ -33,7 +33,6 @@ from transformers import (
 from text_generation_server.utils.tokens import batch_top_tokens
 from text_generation_server.models import Model
-from text_generation_server.utils.tokens import batch_top_tokens
 from text_generation_server.models.types import (
     Batch,
     Tokens,