diff --git a/backends/trtllm/src/backend.rs b/backends/trtllm/src/backend.rs
index 5fcc3d33..c874ca64 100644
--- a/backends/trtllm/src/backend.rs
+++ b/backends/trtllm/src/backend.rs
@@ -38,13 +38,13 @@ pub(crate) struct Generation {
 
 /// Holds the user-provided input to be executed along with a channel allowing
 /// to bubble up all the generated tokens for that request to the end stream.
-#[derive(Clone)]
 pub struct GenerationContext {
     sender: UnboundedSender<InferResult<InferStreamResponse>>,
     tokenizer: Arc<Tokenizer>,
     tokens: Vec<u32>,
     done: Arc<AtomicBool>,
-    start: Instant,
+    queued: Instant,
+    start: Option<Instant>,
 }
 
 impl Stream for Generation {
@@ -160,7 +160,8 @@ impl TensorRtLlmBackend {
             tokenizer,
             tokens: vec![],
             done: Arc::clone(&generation.done),
-            start: Instant::now(),
+            start: None,
+            queued: Instant::now(),
         });
 
         // We are leaking the context on-purpose to avoid the box being dropped while there are
@@ -198,18 +199,31 @@ impl TensorRtLlmBackend {
                 logprob: f32,
                 is_final: bool| {
                     let inner_ctx = &mut *ctx;
+
+                    // Insert the latest generated token to the tracker
                     inner_ctx.tokens.push(token_id);
 
+                    // Update the timestamp at which the request started effectively
+                    // Can be a bit off, would need to be before the callback, let's see
+                    inner_ctx.start.get_or_insert(Instant::now());
+
+                    // Decode the token
                     let text = inner_ctx
                         .tokenizer
                         .decode(&[token_id], true)
                         .expect("Failed to decode token");
+                    let special = inner_ctx
+                        .tokenizer
+                        .get_added_vocabulary()
+                        .is_special_token(&text);
+
+                    // Create the structure holding the token
                     let token = Token {
                         id: token_id,
                         text,
                         logprob,
-                        special: false,
+                        special,
                     };
 
                     let out = if is_final {
@@ -228,8 +242,8 @@ impl TensorRtLlmBackend {
                                 finish_reason: FinishReason::EndOfSequenceToken,
                                 seed: None,
                             },
-                            start: inner_ctx.start,
-                            queued: Instant::now(),
+                            start: inner_ctx.start.unwrap_or(Instant::now()),
+                            queued: inner_ctx.queued,
                         }
                     } else {
                         InferStreamResponse::Intermediate {
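
Notes on the patch, with small standalone sketches (illustrative, not taken verbatim from the backend):

The headline change splits the old `start: Instant` into `queued` (stamped when the request enters the backend) and `start: Option<Instant>` (stamped lazily on the first token callback), so the final `InferStreamResponse::End` can report both timestamps instead of a meaningless `queued: Instant::now()` taken at the end. The lazy stamp relies on `Option::get_or_insert`; a minimal sketch of that pattern, with illustrative names:

```rust
use std::time::Instant;

fn main() {
    // Mirrors the reworked GenerationContext field: `None` until the first
    // token callback fires, then frozen at that first timestamp.
    let mut start: Option<Instant> = None;

    for step in 0..3 {
        // `get_or_insert` stores the value only on the first call;
        // subsequent calls return the already-stored timestamp.
        let first_token_at = *start.get_or_insert(Instant::now());
        println!("step {step}: first token at {first_token_at:?}");
    }
}
```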
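The callback's `let inner_ctx = &mut *ctx;` works through a raw pointer because, per the comment kept in the second hunk, the context is leaked on purpose so the box is not dropped while the executor still holds the pointer. A sketch of that leak-and-reclaim pattern under assumed, simplified types (no FFI, no threads):

```rust
// Illustrative stand-in for GenerationContext; not the backend's type.
struct Ctx {
    tokens: Vec<u32>,
}

fn main() {
    // "Leak" the box: ownership moves to a raw pointer the callback keeps.
    let ctx: *mut Ctx = Box::into_raw(Box::new(Ctx { tokens: vec![] }));

    // Callback side: re-borrow through the raw pointer, as the patch does.
    unsafe {
        let inner_ctx = &mut *ctx;
        inner_ctx.tokens.push(42);
    }

    // After the last callback, reclaim the box so it is dropped exactly once.
    let ctx = unsafe { Box::from_raw(ctx) };
    assert_eq!(ctx.tokens, vec![42]);
}
```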
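The other functional change: `special` is no longer hard-coded to `false` but looked up per token through the tokenizers crate's added vocabulary. A hedged, standalone version of that lookup, assuming a local `tokenizer.json` and an arbitrary token id (both placeholders); note the sketch decodes with `skip_special_tokens = false` so a special token's text survives for the check, whereas the unchanged context line in the patch passes `true`:

```rust
// Assumed dependency: the `tokenizers` crate the backend already uses.
use tokenizers::Tokenizer;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Placeholder path; any Hugging Face `tokenizer.json` works here.
    let tokenizer = Tokenizer::from_file("tokenizer.json")?;

    let token_id = 2u32; // often an EOS id, chosen purely for illustration
    let text = tokenizer.decode(&[token_id], false)?;

    // The same check the patch performs for every generated token.
    let special = tokenizer
        .get_added_vocabulary()
        .is_special_token(&text);

    println!("token {token_id} -> {text:?} (special: {special})");
    Ok(())
}
```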
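One review-style nit on the final hunk: `inner_ctx.start.unwrap_or(Instant::now())` evaluates `Instant::now()` even when `start` is already `Some`, because `unwrap_or` takes its fallback eagerly. Harmless for a cheap clock read, but `unwrap_or_else` skips the call entirely when the value is present:

```rust
use std::time::Instant;

fn main() {
    let start: Option<Instant> = Some(Instant::now());

    // Eager: the argument is computed before `unwrap_or` runs.
    let eager = start.unwrap_or(Instant::now());

    // Lazy: the closure only runs if `start` is `None`.
    let lazy = start.unwrap_or_else(Instant::now);

    println!("{eager:?} {lazy:?}");
}
```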