Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-04-24 00:12:08 +00:00.
Make prefill time of static benchmark correct (#214)
Commit: ea48ae169a
Parent: a8cead1f92
@@ -163,8 +163,11 @@ async fn prefill(
     // Run prefill
     let start_time = Instant::now();
     let (_, decode_batch, _) = client.prefill(batch.clone()).await?;
+    let (_, decode_batch, _) = client.decode(vec![decode_batch.clone().unwrap()]).await?;
 
     // Get latency
     let latency = start_time.elapsed();
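The point of the extra decode call: with speculative scheduling the server already launches the first decode step as part of prefill, so stopping the clock right after client.prefill() under-reports the prefill cost. Below is a minimal sketch of the corrected measurement window, assuming the client types (ShardedClient, Batch, CachedBatch, ClientError) that the benchmark already imports; the helper name and signature are hypothetical, not the benchmark's actual code.

use std::time::{Duration, Instant};
use text_generation_client::{Batch, CachedBatch, ClientError, ShardedClient};

// Hypothetical helper illustrating the corrected timing window: the decode
// step that prefill already scheduled runs before the clock stops, so its
// cost is charged to prefill rather than to the decode phase.
async fn timed_prefill(
    client: &mut ShardedClient,
    batch: Batch,
) -> Result<(Duration, CachedBatch), ClientError> {
    let start_time = Instant::now();

    // Prefill over the prompt tokens; returns the cached batch used for decoding.
    let (_, decode_batch, _) = client.prefill(batch).await?;

    // Run the first decode step inside the measured window.
    let (_, decode_batch, _) = client.decode(vec![decode_batch.unwrap()]).await?;

    let latency = start_time.elapsed();
    Ok((latency, decode_batch.unwrap()))
}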
@@ -180,11 +183,12 @@ async fn prefill(
     };
 
     Ok((step, decode_batch))
 }
 
 /// Run a full decode
 async fn decode(batch: CachedBatch, client: &mut ShardedClient) -> Result<Decode, ClientError> {
-    let mut decode_length = 0;
+    let mut decode_length = 1; // 1 decode step was already scheduled in prefill with speculative scheduling
     let batch_size = batch.size;
 
     let start_time = Instant::now();
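Starting decode_length at 1 keeps the token accounting consistent with the timing change above: one decode step has already been executed (and its time billed to prefill), so only the remaining steps are timed in the decode loop, while the token it produced still counts toward the total. A rough sketch of that accounting follows; the helper name, max_new_tokens parameter, loop shape, and throughput formula are illustrative assumptions, not the benchmark's actual code.

use std::time::Instant;
use text_generation_client::{CachedBatch, ClientError, ShardedClient};

// Hypothetical decode-phase accounting, returning tokens per second.
async fn timed_decode(
    mut batch: CachedBatch,
    max_new_tokens: u32,
    client: &mut ShardedClient,
) -> Result<f64, ClientError> {
    // One decode step already ran inside prefill (speculative scheduling),
    // so its token is counted here even though its time was not.
    let mut decode_length: u32 = 1;
    let batch_size = batch.size;
    let start_time = Instant::now();

    while decode_length < max_new_tokens {
        let (_, next_batch, _) = client.decode(vec![batch.clone()]).await?;
        batch = next_batch.unwrap();
        decode_length += 1;
    }

    let latency = start_time.elapsed();
    // Throughput across the whole batch for the decode phase.
    Ok(f64::from(decode_length * batch_size) / latency.as_secs_f64())
}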
@@ -51,7 +51,7 @@ struct Args {
     runs: usize,
 
     /// Number of warmup cycles
-    #[clap(default_value = "1", short, long, env)]
+    #[clap(default_value = "3", short, long, env)]
     warmups: usize,
 
     /// The location of the grpc socket. This benchmark tool bypasses the router
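The only change here is the default number of warmup cycles, raised from 1 to 3 so the measured runs start from a warmer state. For reference, a trimmed-down sketch of how such a clap derive field behaves; the struct is reduced to the one field shown above and is not the benchmark's full Args definition.

// Requires clap with the "derive" and "env" features enabled.
use clap::Parser;

#[derive(Parser, Debug)]
struct Args {
    /// Number of warmup cycles
    #[clap(default_value = "3", short, long, env)]
    warmups: usize,
}

fn main() {
    // `-w 5`, `--warmups 5`, or a WARMUPS environment variable all override
    // the default of 3 warmup cycles.
    let args = Args::parse();
    println!("running {} warmup cycles", args.warmups);
}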
@@ -33,7 +33,6 @@ from transformers import (
-from text_generation_server.utils.tokens import batch_top_tokens
 from text_generation_server.models import Model
 from text_generation_server.utils.tokens import batch_top_tokens
 from text_generation_server.models.types import (
     Batch,
     Tokens,