From ea48ae169abae05f33747ad99af3c07dfb57ddb2 Mon Sep 17 00:00:00 2001
From: Sun Choi
Date: Mon, 26 Aug 2024 01:51:28 -0700
Subject: [PATCH] Make prefill time of static benchmark correct (#214)

---
 benchmark/src/generation.rs                        | 6 +++++-
 benchmark/src/main.rs                              | 2 +-
 server/text_generation_server/models/causal_lm.py | 1 -
 3 files changed, 6 insertions(+), 3 deletions(-)

diff --git a/benchmark/src/generation.rs b/benchmark/src/generation.rs
index ea7c9778..b2766d0c 100644
--- a/benchmark/src/generation.rs
+++ b/benchmark/src/generation.rs
@@ -163,8 +163,11 @@ async fn prefill(
 
     // Run prefill
     let start_time = Instant::now();
+
     let (_, decode_batch, _) = client.prefill(batch.clone()).await?;
+    let (_, decode_batch, _) = client.decode(vec![decode_batch.clone().unwrap()]).await?;
+
     // Get latency
     let latency = start_time.elapsed();
 
@@ -180,11 +183,12 @@ async fn prefill(
     };
 
     Ok((step, decode_batch))
+
 }
 
 /// Run a full decode
 async fn decode(batch: CachedBatch, client: &mut ShardedClient) -> Result {
-    let mut decode_length = 0;
+    let mut decode_length = 1; // 1 decode step was already scheduled in prefill with speculative scheduling
     let batch_size = batch.size;
 
     let start_time = Instant::now();

diff --git a/benchmark/src/main.rs b/benchmark/src/main.rs
index 2d89e045..34b91a92 100644
--- a/benchmark/src/main.rs
+++ b/benchmark/src/main.rs
@@ -51,7 +51,7 @@ struct Args {
     runs: usize,
 
     /// Number of warmup cycles
-    #[clap(default_value = "1", short, long, env)]
+    #[clap(default_value = "3", short, long, env)]
     warmups: usize,
 
     /// The location of the grpc socket. This benchmark tool bypasses the router

diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index be3b8f4d..c6192be0 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -33,7 +33,6 @@ from transformers import (
 
 from text_generation_server.utils.tokens import batch_top_tokens
 from text_generation_server.models import Model
-from text_generation_server.utils.tokens import batch_top_tokens
 from text_generation_server.models.types import (
     Batch,
     Tokens,
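
Note (not part of the patch): a minimal, self-contained Rust sketch of the timing pattern this change moves to. MockClient and its sleep-based latencies are made-up stand-ins for the real gRPC ShardedClient; the point is only that, when the backend schedules work asynchronously (speculative scheduling), the first token is not available until the first decode call returns, so the "prefill" measurement has to cover prefill plus one decode step, and the decode loop then starts counting from 1.

// Sketch only: hypothetical MockClient in place of the real ShardedClient.
use std::thread::sleep;
use std::time::{Duration, Instant};

struct MockClient;

impl MockClient {
    // With asynchronous/speculative scheduling, prefill() only enqueues the
    // prompt; the first token materializes on the next decode() call.
    fn prefill(&mut self) { sleep(Duration::from_millis(30)); }
    fn decode(&mut self) { sleep(Duration::from_millis(10)); }
}

fn main() {
    let mut client = MockClient;
    let decode_steps_requested = 4;

    // Corrected prefill measurement: time prefill together with the first
    // decode step, so the reported latency is the time to the first token.
    let start_time = Instant::now();
    client.prefill();
    client.decode();
    println!("prefill (time to first token): {:?}", start_time.elapsed());

    // One decode step was already consumed above, so the remaining decode
    // loop starts at 1, mirroring `let mut decode_length = 1;` in the patch.
    let mut decode_length = 1;
    let start_time = Instant::now();
    while decode_length < decode_steps_requested {
        client.decode();
        decode_length += 1;
    }
    println!("remaining decode time: {:?}", start_time.elapsed());
}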