feat: add more latency metrics in forward (#1346)

Author: OlivierDehaene, 2023-12-14 15:59:38 +01:00; committed by Karol Damaszke
parent c974437ba7
commit 5c9ef069ed
17 changed files with 242 additions and 112 deletions

View File

@ -163,7 +163,7 @@ async fn prefill(
// Run prefill
let start_time = Instant::now();
let (_, decode_batch) = client.prefill(batch.clone()).await?;
let (_, decode_batch, _) = client.prefill(batch.clone()).await?;
// Get latency
let latency = start_time.elapsed();

View File

@ -182,6 +182,12 @@ message PrefillResponse {
repeated Generation generations = 1;
/// Next batch (cached)
optional CachedBatch batch = 2;
/// Forward elapsed time in nanoseconds
uint64 forward_ns = 3;
/// Decode elapsed time in nanoseconds
uint64 decode_ns = 4;
/// Total elapsed time in nanoseconds
uint64 total_ns = 5;
}
message DecodeRequest {
@ -194,6 +200,14 @@ message DecodeResponse {
repeated Generation generations = 1;
/// Next batch (cached)
optional CachedBatch batch = 2;
/// Forward elapsed time in nanoseconds
uint64 forward_ns = 3;
/// Decode elapsed time in nanoseconds
uint64 decode_ns = 4;
/// Total elapsed time in nanoseconds
uint64 total_ns = 5;
/// Concatenate elapsed time in nanoseconds
optional uint64 concat_ns = 6;
}
message WarmupRequest {
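
The three new duration fields report where server-side time went: `forward_ns` and `decode_ns` are measured inside the model's `generate_token`, while `total_ns` wraps the whole gRPC handler and therefore also covers batch deserialization, cache updates and response serialization. A minimal, purely illustrative sketch of that arithmetic using the generated Python messages (the numeric values are invented, and the `text_generation_server.pb` import path is assumed; in practice these fields come back from the Prefill / Decode RPCs defined above):

from text_generation_server.pb import generate_pb2

# Illustrative response, populated the way the server does further below (values invented)
response = generate_pb2.PrefillResponse(
    forward_ns=182_000_000,  # model forward stage
    decode_ns=9_500_000,     # post-processing: detokenization, stopping criteria
    total_ns=195_000_000,    # whole handler, including pb (de)serialization
)

overhead_ns = response.total_ns - (response.forward_ns + response.decode_ns)
print(
    f"forward {response.forward_ns / 1e6:.1f} ms, "
    f"decode {response.decode_ns / 1e6:.1f} ms, "
    f"handler overhead {overhead_ns / 1e6:.1f} ms"
)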

View File

@ -8,6 +8,7 @@ use std::env;
use rand::{distributions::Uniform, Rng};
use grpc_metadata::InjectTelemetryContext;
use std::cmp;
use std::time::Duration;
use tonic::transport::{Channel, Uri};
use tracing::instrument;
@ -294,10 +295,14 @@ impl Client {
pub async fn prefill(
&mut self,
batch: Batch,
) -> Result<(Vec<Generation>, Option<CachedBatch>)> {
) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context();
let response = self.stub.prefill(request).await?.into_inner();
Ok((response.generations, response.batch))
Ok((
response.generations,
response.batch,
PrefillTimings::new(response.forward_ns, response.decode_ns, response.total_ns),
))
}
/// Generate one token for each request in the given cached batches
@ -308,9 +313,52 @@ impl Client {
pub async fn decode(
&mut self,
batches: Vec<CachedBatch>,
) -> Result<(Vec<Generation>, Option<CachedBatch>)> {
) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
let request = tonic::Request::new(DecodeRequest { batches }).inject_context();
let response = self.stub.decode(request).await?.into_inner();
Ok((response.generations, response.batch))
Ok((
response.generations,
response.batch,
DecodeTimings::new(
response.concat_ns,
response.forward_ns,
response.decode_ns,
response.total_ns,
),
))
}
}
pub struct PrefillTimings {
pub forward: Duration,
pub decode: Duration,
pub total: Duration,
}
impl PrefillTimings {
fn new(forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
Self {
forward: Duration::from_nanos(forward_ns),
decode: Duration::from_nanos(decode_ns),
total: Duration::from_nanos(total_ns),
}
}
}
pub struct DecodeTimings {
pub concat: Option<Duration>,
pub forward: Duration,
pub decode: Duration,
pub total: Duration,
}
impl DecodeTimings {
fn new(concat_ns: Option<u64>, forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
Self {
concat: concat_ns.map(|v| Duration::from_nanos(v)),
forward: Duration::from_nanos(forward_ns),
decode: Duration::from_nanos(decode_ns),
total: Duration::from_nanos(total_ns),
}
}
}

View File

@ -1,5 +1,6 @@
/// Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.
use crate::client::{DecodeTimings, PrefillTimings};
/// Multi shard Client
use crate::{Batch, CachedBatch, Client, Generation, HealthResponse, ShardInfo};
use crate::{ClientError, Result};
@ -119,49 +120,63 @@ impl ShardedClient {
///
/// Returns Generation for each request in batch
/// and the next cached batch
#[instrument(skip_all, fields(id = &batch.id, size = &batch.size))]
#[instrument(skip_all, fields(id = & batch.id, size = & batch.size))]
pub async fn prefill(
&mut self,
batch: Batch,
) -> Result<(Vec<Generation>, Option<CachedBatch>)> {
) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| Box::pin(client.prefill(batch.clone())))
.collect();
let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>)>> =
let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)>> =
join_all(futures).await.into_iter().collect();
merge_generations(results?)
let mut results = results?;
let (mut generations, next_batch, mut timings) =
results.pop().ok_or(ClientError::EmptyResults)?;
// Merge generations from different model shards
for (mut shard_generations, _, shard_timings) in results.into_iter() {
generations.append(&mut shard_generations);
// Return the timings of the slowest shard
if shard_timings.total > timings.total {
timings = shard_timings;
}
}
Ok((generations, next_batch, timings))
}
/// Generate one token for each request in the given cached batches
///
/// Returns Generation for each request in batches
/// and the next cached batch
#[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::<u32>()))]
#[instrument(skip_all, fields(size = batches.iter().map(| batch | {batch.size}).sum::< u32 > ()))]
pub async fn decode(
&mut self,
batches: Vec<CachedBatch>,
) -> Result<(Vec<Generation>, Option<CachedBatch>)> {
) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| Box::pin(client.decode(batches.clone())))
.collect();
let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>)>> =
let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)>> =
join_all(futures).await.into_iter().collect();
merge_generations(results?)
let mut results = results?;
let (mut generations, next_batch, mut timings) =
results.pop().ok_or(ClientError::EmptyResults)?;
// Merge generations from different model shards
for (mut shard_generations, _, shard_timings) in results.into_iter() {
generations.append(&mut shard_generations);
// Return the timings of the slowest shard
if shard_timings.total > timings.total {
timings = shard_timings;
}
}
Ok((generations, next_batch, timings))
}
}
/// Merge generations from the different model shards
fn merge_generations(
mut results: Vec<(Vec<Generation>, Option<CachedBatch>)>,
) -> Result<(Vec<Generation>, Option<CachedBatch>)> {
let (mut generations, next_batch) = results.pop().ok_or(ClientError::EmptyResults)?;
for (mut shard_generations, _) in results.into_iter() {
generations.append(&mut shard_generations);
}
Ok((generations, next_batch))
}

View File

@ -390,15 +390,20 @@ async fn prefill(
metrics::increment_counter!("tgi_batch_inference_count", "method" => "prefill");
match client.prefill(batch).await {
Ok((generations, next_batch)) => {
Ok((generations, next_batch, timings)) => {
// Update health
generation_health.store(true, Ordering::SeqCst);
let start_filtering_time = Instant::now();
// Send generated tokens and filter stopped entries
filter_send_generations(generations, entries);
// Filter next batch and remove requests that were stopped
let next_batch = filter_batch(client, next_batch, entries).await;
metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "prefill");
metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "prefill");
metrics::histogram!("tgi_batch_filter_duration", start_filtering_time.elapsed().as_secs_f64(), "method" => "prefill");
metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "prefill");
metrics::increment_counter!("tgi_batch_inference_success", "method" => "prefill");
next_batch
@ -427,15 +432,23 @@ async fn decode(
metrics::increment_counter!("tgi_batch_inference_count", "method" => "decode");
match client.decode(batches).await {
Ok((generations, next_batch)) => {
Ok((generations, next_batch, timings)) => {
// Update health
generation_health.store(true, Ordering::SeqCst);
let start_filtering_time = Instant::now();
// Send generated tokens and filter stopped entries
filter_send_generations(generations, entries);
// Filter next batch and remove requests that were stopped
let next_batch = filter_batch(client, next_batch, entries).await;
if let Some(concat_duration) = timings.concat {
metrics::histogram!("tgi_batch_concat_duration", concat_duration.as_secs_f64(), "method" => "decode");
}
metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "decode");
metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "decode");
metrics::histogram!("tgi_batch_filter_duration", start_filtering_time.elapsed().as_secs_f64(), "method" => "decode");
metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "decode");
metrics::increment_counter!("tgi_batch_inference_success", "method" => "decode");
next_batch

View File

@ -569,7 +569,7 @@ mod tests {
let max_stop_sequence = 3;
let max_top_n_tokens = 4;
let max_input_length = 5;
let max_total_tokens = 6;
let max_total_tokens = 106;
let workers = 1;
let validation = Validation::new(
workers,
@ -629,7 +629,7 @@ mod tests {
let max_stop_sequences = 3;
let max_top_n_tokens = 4;
let max_input_length = 5;
let max_total_tokens = 6;
let max_total_tokens = 106;
let workers = 1;
let validation = Validation::new(
workers,

View File

@ -105,7 +105,7 @@ def test_causal_lm_batch_type(default_bloom):
@pytest.mark.skip
def test_causal_lm_generate_token(default_bloom, default_bloom_batch):
sequence_length = len(default_bloom_batch.all_input_ids[0])
generations, next_batch = default_bloom.generate_token(default_bloom_batch)
generations, next_batch, _ = default_bloom.generate_token(default_bloom_batch)
assert len(generations) == len(default_bloom_batch)
assert isinstance(next_batch, CausalLMBatch)
@ -156,10 +156,10 @@ def test_causal_lm_generate_token(default_bloom, default_bloom_batch):
def test_causal_lm_generate_token_completion(default_bloom, default_bloom_batch):
next_batch = default_bloom_batch
for _ in range(default_bloom_batch.stopping_criterias[0].max_new_tokens - 1):
generations, next_batch = default_bloom.generate_token(next_batch)
generations, next_batch, _ = default_bloom.generate_token(next_batch)
assert len(generations) == len(default_bloom_batch)
generations, next_batch = default_bloom.generate_token(next_batch)
generations, next_batch, _ = default_bloom.generate_token(next_batch)
assert next_batch is None
assert len(generations) == 1
@ -182,10 +182,10 @@ def test_causal_lm_generate_token_completion_multi(
for i in range(
default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens - 1
):
generations, next_batch = default_bloom.generate_token(next_batch)
generations, next_batch, _ = default_bloom.generate_token(next_batch)
assert len(generations) == len(default_multi_requests_bloom_batch)
generations, next_batch = default_bloom.generate_token(next_batch)
generations, next_batch, _ = default_bloom.generate_token(next_batch)
assert next_batch is not None
assert len(generations) == 2
@ -205,10 +205,10 @@ def test_causal_lm_generate_token_completion_multi(
for _ in range(
stopping_criterias[0].max_new_tokens - stopping_criterias[1].max_new_tokens - 1
):
generations, next_batch = default_bloom.generate_token(next_batch)
generations, next_batch, _ = default_bloom.generate_token(next_batch)
assert len(generations) == len(next_batch)
generations, next_batch = default_bloom.generate_token(next_batch)
generations, next_batch, _ = default_bloom.generate_token(next_batch)
assert next_batch is None
assert len(generations) == 1
@ -229,11 +229,11 @@ def test_batch_concatenate(
default_bloom, default_bloom_batch, default_multi_requests_bloom_batch
):
next_batch_0 = default_bloom_batch
_, next_batch_0 = default_bloom.generate_token(next_batch_0)
_, next_batch_0 = default_bloom.generate_token(next_batch_0)
_, next_batch_0, _ = default_bloom.generate_token(next_batch_0)
_, next_batch_0, _ = default_bloom.generate_token(next_batch_0)
next_batch_1 = default_multi_requests_bloom_batch
_, next_batch_1 = default_bloom.generate_token(next_batch_1)
_, next_batch_1, _ = default_bloom.generate_token(next_batch_1)
# Clone past_key_values before concatenating to compare after,
# because they are removed from the concatenated batches
@ -293,10 +293,10 @@ def test_batch_concatenate(
for _ in range(
default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens - 2
):
generations, next_batch = default_bloom.generate_token(next_batch)
generations, next_batch, _ = default_bloom.generate_token(next_batch)
assert len(generations) == len(next_batch)
generations, next_batch = default_bloom.generate_token(next_batch)
generations, next_batch, _ = default_bloom.generate_token(next_batch)
assert next_batch is not None
assert len(generations) == 3
@ -318,10 +318,10 @@ def test_batch_concatenate(
- default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
- 2
):
generations, next_batch = default_bloom.generate_token(next_batch)
generations, next_batch, _ = default_bloom.generate_token(next_batch)
assert len(generations) == len(next_batch)
generations, next_batch = default_bloom.generate_token(next_batch)
generations, next_batch, _ = default_bloom.generate_token(next_batch)
assert next_batch is not None
assert len(generations) == 2
@ -342,10 +342,10 @@ def test_batch_concatenate(
- default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
- 4
):
generations, next_batch = default_bloom.generate_token(next_batch)
generations, next_batch, _ = default_bloom.generate_token(next_batch)
assert len(generations) == len(next_batch)
generations, next_batch = default_bloom.generate_token(next_batch)
generations, next_batch, _ = default_bloom.generate_token(next_batch)
assert next_batch is None
assert len(generations) == 1

View File

@ -111,7 +111,9 @@ def test_causal_lm_batch_type(default_causal_lm):
def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch):
sequence_length = len(default_causal_lm_batch.all_input_ids[0])
generations, next_batch = default_causal_lm.generate_token(default_causal_lm_batch)
generations, next_batch, _ = default_causal_lm.generate_token(
default_causal_lm_batch
)
assert len(generations) == len(next_batch)
assert isinstance(next_batch, CausalLMBatch)
@ -163,10 +165,10 @@ def test_causal_lm_generate_token_completion(
):
next_batch = default_causal_lm_batch
for _ in range(default_causal_lm_batch.stopping_criterias[0].max_new_tokens - 1):
generations, next_batch = default_causal_lm.generate_token(next_batch)
generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
assert len(generations) == len(next_batch)
generations, next_batch = default_causal_lm.generate_token(next_batch)
generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
assert next_batch is None
assert len(generations) == 1
@ -186,10 +188,10 @@ def test_causal_lm_generate_token_completion_multi(
for i in range(
default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens - 1
):
generations, next_batch = default_causal_lm.generate_token(next_batch)
generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
assert len(generations) == len(next_batch)
generations, next_batch = default_causal_lm.generate_token(next_batch)
generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
assert next_batch is not None
assert len(generations) == 2
@ -212,10 +214,10 @@ def test_causal_lm_generate_token_completion_multi(
for _ in range(
stopping_criterias[0].max_new_tokens - stopping_criterias[1].max_new_tokens - 1
):
generations, next_batch = default_causal_lm.generate_token(next_batch)
generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
assert len(generations) == len(next_batch)
generations, next_batch = default_causal_lm.generate_token(next_batch)
generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
assert next_batch is None
assert len(generations) == 1
@ -234,11 +236,11 @@ def test_batch_concatenate(
default_causal_lm, default_causal_lm_batch, default_multi_requests_causal_lm_batch
):
next_batch_0 = default_causal_lm_batch
_, next_batch_0 = default_causal_lm.generate_token(next_batch_0)
_, next_batch_0 = default_causal_lm.generate_token(next_batch_0)
_, next_batch_0, _ = default_causal_lm.generate_token(next_batch_0)
_, next_batch_0, _ = default_causal_lm.generate_token(next_batch_0)
next_batch_1 = default_multi_requests_causal_lm_batch
_, next_batch_1 = default_causal_lm.generate_token(next_batch_1)
_, next_batch_1, _ = default_causal_lm.generate_token(next_batch_1)
# Clone past_key_values before concatenating to compare after,
# because they are removed from the concatenated batches
@ -297,10 +299,10 @@ def test_batch_concatenate(
for _ in range(
default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens - 2
):
generations, next_batch = default_causal_lm.generate_token(next_batch)
generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
assert len(generations) == len(next_batch)
generations, next_batch = default_causal_lm.generate_token(next_batch)
generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
assert next_batch is not None
assert len(generations) == 3
@ -323,10 +325,10 @@ def test_batch_concatenate(
- default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
- 2
):
generations, next_batch = default_causal_lm.generate_token(next_batch)
generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
assert len(generations) == len(next_batch)
generations, next_batch = default_causal_lm.generate_token(next_batch)
generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
assert next_batch is not None
assert len(generations) == 2
@ -345,10 +347,10 @@ def test_batch_concatenate(
- default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
- 4
):
generations, next_batch = default_causal_lm.generate_token(next_batch)
generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
assert len(generations) == len(next_batch)
generations, next_batch = default_causal_lm.generate_token(next_batch)
generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
assert next_batch is None
assert len(generations) == 1

View File

@ -55,10 +55,10 @@ def test_santacoder_generate_token_completion(default_santacoder, default_pb_bat
next_batch = batch
for _ in range(batch.stopping_criterias[0].max_new_tokens - 1):
generations, next_batch = default_santacoder.generate_token(next_batch)
generations, next_batch, _ = default_santacoder.generate_token(next_batch)
assert len(generations) == len(next_batch)
generations, next_batch = default_santacoder.generate_token(next_batch)
generations, next_batch, _ = default_santacoder.generate_token(next_batch)
assert next_batch is None
assert len(generations) == 1
@ -83,10 +83,10 @@ def test_fim_santacoder_generate_token_completion(
next_batch = batch
for _ in range(batch.stopping_criterias[0].max_new_tokens - 1):
generations, next_batch = default_santacoder.generate_token(next_batch)
generations, next_batch, _ = default_santacoder.generate_token(next_batch)
assert len(generations) == len(next_batch)
generations, next_batch = default_santacoder.generate_token(next_batch)
generations, next_batch, _ = default_santacoder.generate_token(next_batch)
assert next_batch is None
assert len(generations) == 1

View File

@ -107,7 +107,7 @@ def test_seq2seq_lm_batch_type(default_seq2seq_lm):
@pytest.mark.skip("seq2seq model not enabled on HPU yet")
def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch):
sequence_length = len(default_seq2seq_lm_batch.input_ids[0])
generations, next_batch = default_seq2seq_lm.generate_token(
generations, next_batch, _ = default_seq2seq_lm.generate_token(
default_seq2seq_lm_batch
)
@ -178,10 +178,10 @@ def test_seq2seq_lm_generate_token_completion(
):
next_batch = default_seq2seq_lm_batch
for _ in range(6):
generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
assert len(generations) == len(next_batch)
generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
assert next_batch is None
assert len(generations) == 1
@ -197,10 +197,10 @@ def test_seq2seq_lm_generate_token_completion_multi(
next_batch = default_multi_requests_seq2seq_lm_batch
for i in range(4):
generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
assert len(generations) == len(next_batch)
generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
assert next_batch is not None
assert len(generations) == 2
@ -213,10 +213,10 @@ def test_seq2seq_lm_generate_token_completion_multi(
next_batch = next_batch.filter([next_batch.requests[0].id])
generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
assert len(generations) == len(next_batch)
generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
assert next_batch is None
assert len(generations) == 1
@ -235,11 +235,11 @@ def test_batch_concatenate(
default_multi_requests_seq2seq_lm_batch,
):
next_batch_0 = default_seq2seq_lm_batch
_, next_batch_0 = default_seq2seq_lm.generate_token(next_batch_0)
_, next_batch_0 = default_seq2seq_lm.generate_token(next_batch_0)
_, next_batch_0, _ = default_seq2seq_lm.generate_token(next_batch_0)
_, next_batch_0, _ = default_seq2seq_lm.generate_token(next_batch_0)
next_batch_1 = default_multi_requests_seq2seq_lm_batch
_, next_batch_1 = default_seq2seq_lm.generate_token(next_batch_1)
_, next_batch_1, _ = default_seq2seq_lm.generate_token(next_batch_1)
# Copy hidden state because it is removed from the concatenated branches
next_batch_0_encoder_last_hidden_state = next_batch_0.encoder_last_hidden_state
@ -331,10 +331,10 @@ def test_batch_concatenate(
)
for _ in range(3):
generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
assert len(generations) == len(next_batch)
generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
assert next_batch is not None
assert len(generations) == 3
@ -349,7 +349,7 @@ def test_batch_concatenate(
[next_batch.requests[0].id, next_batch.requests[1].id]
)
generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
assert next_batch is not None
assert len(generations) == 2
@ -359,7 +359,7 @@ def test_batch_concatenate(
next_batch = next_batch.filter([next_batch.requests[1].id])
generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
assert next_batch is None
assert len(generations) == 1

View File

@ -7,6 +7,7 @@ import itertools
import math
import os
import tempfile
import time
from typing import Dict, List, Optional, Tuple, Type
import torch
@ -33,6 +34,7 @@ from transformers import (
from text_generation_server.utils.tokens import batch_top_tokens
from text_generation_server.models import Model
from text_generation_server.utils.tokens import batch_top_tokens
from text_generation_server.models.types import (
Batch,
Tokens,
@ -821,7 +823,10 @@ class CausalLM(Model):
return outputs.logits, outputs.past_key_values
@tracer.start_as_current_span("generate_token")
def generate_token(self, batches: List[CausalLMBatch]) -> Tuple[List[Generation], Optional[CausalLMBatch]]:
def generate_token(
self, batches: List[CausalLMBatch]
) -> Tuple[List[Generation], Optional[CausalLMBatch], Tuple[int, int]]:
start = time.time_ns()
# Results
generations: List[Generation] = []
prev_batches = []
@ -939,6 +944,8 @@ class CausalLM(Model):
htorch.core.mark_step()
start_decode = time.time_ns()
# Stage 3. Finish and return previous generations
stopped = len(requests_to_generate) > 0
for prev_batch in prev_batches:
@ -1073,13 +1080,16 @@ class CausalLM(Model):
self.hb_profiler.stop()
else:
self.hb_profiler.step()
return generations, batch if not stopped else None
forward_ns = start_decode - start
decode_ns = time.time_ns() - start_decode
return generations, batch if not stopped else None, (forward_ns, decode_ns)
def warmup(self, batches: List[CausalLMBatch]) -> None:
# prefill
_, prefill_batch = self.generate_token([batches.pop(0)])
_, prefill_batch, _ = self.generate_token([batches.pop(0)])
# decode
_, decode_batch = self.generate_token([prefill_batch])
_, decode_batch, _ = self.generate_token([prefill_batch])
# shifts
self.shifting_warmup(decode_batch)
@ -1088,12 +1098,12 @@ class CausalLM(Model):
return
# prefill
_, prefill_batch = self.generate_token([batches.pop(0)])
_, prefill_batch, _ = self.generate_token([batches.pop(0)])
# concatenate and decode
_, decode_batch = self.generate_token([decode_batch, prefill_batch])
_, decode_batch, _ = self.generate_token([decode_batch, prefill_batch])
# decodes
while decode_batch is not None:
_, decode_batch = self.generate_token([decode_batch])
_, decode_batch, _ = self.generate_token([decode_batch])
def shifting_warmup(self, batch: CausalLMBatch) -> None:
chunk_sizes = CHUNK_SIZES.copy()

View File

@ -1,6 +1,6 @@
import math
import time
import itertools
from text_generation_server.utils.tokens import batch_top_tokens
import torch
import torch.distributed
@ -9,9 +9,10 @@ import numpy as np
from dataclasses import dataclass
from opentelemetry import trace
from transformers import PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type, Union, Dict
from typing import Optional, Tuple, List, Type, Dict
from text_generation_server.models import Model
from text_generation_server.utils.tokens import batch_top_tokens
from text_generation_server.utils.speculate import get_speculate
from text_generation_server.models.types import (
Batch,
@ -689,7 +690,7 @@ class FlashCausalLM(Model):
self.dtype,
self.device,
)
_, batch = self.generate_token(batch)
_, batch, _ = self.generate_token(batch)
except torch.cuda.OutOfMemoryError as e:
raise RuntimeError(
f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. "
@ -799,7 +800,8 @@ class FlashCausalLM(Model):
@tracer.start_as_current_span("generate_token")
def generate_token(
self, batch: FlashCausalLMBatch
) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
) -> Tuple[List[Generation], Optional[FlashCausalLMBatch], Tuple[int, int]]:
start = time.time_ns()
prefill = batch.cu_seqlen_prefill is not None
prefill_logprobs = batch.prefill_next_token_indices is not None
@ -941,6 +943,8 @@ class FlashCausalLM(Model):
# GPU <-> CPU sync
next_token_logprobs = next_token_logprobs.tolist()
next_token_ids = next_input_ids.tolist()
accepted_ids = accepted_ids.tolist()
start_decode = time.time_ns()
# Zipped iterator
iterator = zip(
@ -977,7 +981,6 @@ class FlashCausalLM(Model):
# Append next token to all tokens
next_token_texts = []
left = 0
before = stopping_criteria.current_tokens
current_stopped = False
for j in range(index, index + n_accepted_ids):
@ -1092,7 +1095,7 @@ class FlashCausalLM(Model):
generations.append(generation)
# Update values
batch.input_lengths[i] = input_length + n_accepted_ids.item()
batch.input_lengths[i] = input_length + n_accepted_ids
if batch.input_lengths[i] > batch.max_seqlen:
batch.max_seqlen = batch.input_lengths[i]
batch.prefix_offsets[i] = prefix_offset
@ -1102,10 +1105,14 @@ class FlashCausalLM(Model):
if stopped:
del batch
# No need to return a batch if we know that all requests stopped
return generations, None
forward_ns = start_decode - start
decode_ns = time.time_ns() - start_decode
return generations, None, (forward_ns, decode_ns)
batch.prefill_cu_outlens = None
batch.prefill_head_indices = None
batch.prefill_next_token_indices = None
return generations, batch
forward_ns = start_decode - start
decode_ns = time.time_ns() - start_decode
return generations, batch, (forward_ns, decode_ns)

View File

@ -1,17 +1,11 @@
import torch
import inspect
import re
from io import BytesIO
import base64
from PIL import Image
import re
import time
from dataclasses import dataclass
from opentelemetry import trace
from transformers import (
AutoProcessor,
AutoTokenizer,
AutoModelForCausalLM,
PreTrainedTokenizerBase,
ProcessorMixin,
)
@ -670,7 +664,8 @@ class IdeficsCausalLM(Model):
@tracer.start_as_current_span("generate_token")
def generate_token(
self, batch: IdeficsCausalLMBatch
) -> Tuple[List[Generation], Optional[IdeficsCausalLMBatch]]:
) -> Tuple[List[Generation], Optional[IdeficsCausalLMBatch], Tuple[int, int]]:
start = time.time_ns()
# slice the attention mask to the correct shape
attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]
if batch.input_ids.size(1) == 1:
@ -699,6 +694,8 @@ class IdeficsCausalLM(Model):
# Hardcoded remove image tokens
logits[:, 32000:32001] = torch.finfo(logits.dtype).min
start_decode = time.time_ns()
# Results
generations: List[Generation] = []
stopped = True
@ -827,7 +824,9 @@ class IdeficsCausalLM(Model):
# We finished all generations in the batch; there is no next batch
if stopped:
return generations, None
forward_ns = start_decode - start
decode_ns = time.time_ns() - start_decode
return generations, None, (forward_ns, decode_ns)
# Slice unused values from prefill
batch.input_ids = batch.input_ids[:, :1]
@ -847,4 +846,6 @@ class IdeficsCausalLM(Model):
batch.past_key_values = past
batch.image_hidden_states = image_hidden_states
return generations, batch
forward_ns = start_decode - start
decode_ns = time.time_ns() - start_decode
return generations, batch, (forward_ns, decode_ns)

View File

@ -60,7 +60,9 @@ class Model(ABC):
raise NotImplementedError
@abstractmethod
def generate_token(self, batch: B) -> Tuple[List[Generation], Optional[B]]:
def generate_token(
self, batch: B
) -> Tuple[List[Generation], Optional[B], Tuple[int, int]]:
raise NotImplementedError
def warmup(self, batch: B, max_total_tokens: int):
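
The abstract `generate_token` contract now returns a third element: a `(forward_ns, decode_ns)` pair of integers in nanoseconds, taken with `time.time_ns()` around the forward stage and the post-processing stage respectively (the gRPC servicer adds `total_ns` on top). A toy stand-in showing the expected return shape, not tied to any real model; every name in it is illustrative:

import time
from typing import List, Optional, Tuple


def generate_token(batch) -> Tuple[List["Generation"], Optional["Batch"], Tuple[int, int]]:
    """Toy stand-in for Model.generate_token illustrating the new return shape."""
    start = time.time_ns()

    # ... model forward pass(es) and next-token selection would run here ...
    start_decode = time.time_ns()

    # ... detokenization, stopping criteria and Generation construction would run here ...
    generations: List["Generation"] = []
    next_batch = None  # None once every request in the batch has stopped

    forward_ns = start_decode - start
    decode_ns = time.time_ns() - start_decode
    return generations, next_batch, (forward_ns, decode_ns)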

View File

@ -1,11 +1,12 @@
from text_generation_server.utils.tokens import batch_top_tokens
import torch
import time
from dataclasses import dataclass
from opentelemetry import trace
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type, Dict
from text_generation_server.utils.tokens import batch_top_tokens
from text_generation_server.models import Model
from text_generation_server.models.types import (
GeneratedText,
@ -613,7 +614,8 @@ class Seq2SeqLM(Model):
@tracer.start_as_current_span("generate_token")
def generate_token(
self, batch: Seq2SeqLMBatch
) -> Tuple[List[Generation], Optional[Seq2SeqLMBatch]]:
) -> Tuple[List[Generation], Optional[Seq2SeqLMBatch], Tuple[int, int]]:
start = time.time_ns()
if batch.decoder_attention_mask is not None:
# slice to the correct shape
decoder_attention_mask = batch.decoder_attention_mask[
@ -644,6 +646,8 @@ class Seq2SeqLM(Model):
torch.log_softmax(logits[:, -1], -1),
)
start_decode = time.time_ns()
# Finished requests
generations: List[Generation] = []
stopped = True
@ -788,7 +792,9 @@ class Seq2SeqLM(Model):
# We finished all generations in the batch; there is no next batch
if stopped:
return generations, None
forward_ns = start_decode - start
decode_ns = time.time_ns() - start_decode
return generations, None, (forward_ns, decode_ns)
# We don't need input_ids after the prefill forward
batch.input_ids = None
@ -799,4 +805,6 @@ class Seq2SeqLM(Model):
batch.decoder_attention_mask[:, -batch.padding_right_offset] = 1
batch.padding_right_offset -= 1
return generations, batch
forward_ns = start_decode - start
decode_ns = time.time_ns() - start_decode
return generations, batch, (forward_ns, decode_ns)

View File

@ -4,6 +4,7 @@ import asyncio
import os
import sys
import torch
import time
from grpc import aio
from loguru import logger
@ -70,18 +71,23 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
return generate_pb2.WarmupResponse()
async def Prefill(self, request, context):
start = time.time_ns()
batch = self.model.batch_type.from_pb(
request.batch, self.model.tokenizer, self.model.dtype, self.model.device
)
generations, next_batch = self.model.generate_token([batch])
generations, next_batch, timings = self.model.generate_token([batch])
self.cache.set(next_batch)
return generate_pb2.PrefillResponse(
generations=[generation.to_pb() for generation in generations],
batch=next_batch.to_pb() if next_batch else None,
forward_ns=timings[0],
decode_ns=timings[1],
total_ns=time.time_ns() - start,
)
async def Decode(self, request, context):
start = time.time_ns()
if len(request.batches) == 0:
raise ValueError("Must provide at least one batch")
@ -95,12 +101,16 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
if len(batches) == 0:
raise ValueError("All batches are empty")
generations, next_batch = self.model.generate_token(batches)
generations, next_batch, timings = self.model.generate_token(batches)
self.cache.set(next_batch)
return generate_pb2.DecodeResponse(
generations=[generation.to_pb() for generation in generations],
batch=next_batch.to_pb() if next_batch else None,
concat_ns=None, # TODO: measure concat time
forward_ns=timings[0],
decode_ns=timings[1],
total_ns=time.time_ns() - start,
)

View File

@ -89,7 +89,7 @@ class NextTokenChooser:
class StopSequenceCriteria:
def __init__(self, stop_sequence: str):
stop_sequence = re.escape(stop_sequence)
self.regex = re.compile(f".*{stop_sequence}$")
self.regex = re.compile(f"{stop_sequence}$")
def __call__(self, output: str) -> bool:
if self.regex.findall(output):
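
The `StopSequenceCriteria` pattern drops the leading `.*`: when the compiled regex is used with `findall` (or `search`), anchoring on `$` alone already detects the stop sequence at the end of the generated text, so the greedy `.*` prefix was redundant for this check. A small standalone illustration, using "</s>" as an assumed example stop sequence:

import re

stop_sequence = re.escape("</s>")           # same escaping as StopSequenceCriteria
criteria = re.compile(f"{stop_sequence}$")  # new form, without the leading ".*"

assert criteria.findall("Hello world</s>")       # truthy: output ends with the stop sequence
assert not criteria.findall("Hello </s> world")  # falsy: stop sequence not at the end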