diff --git a/backends/trtllm/src/backend.rs b/backends/trtllm/src/backend.rs
index d4113725..5d8f53b6 100644
--- a/backends/trtllm/src/backend.rs
+++ b/backends/trtllm/src/backend.rs
@@ -15,7 +15,7 @@ use tokio::sync::RwLock;
 use tokio::time::{Instant, sleep};
 use tokio_stream::{Stream, StreamExt};
 use tokio_stream::wrappers::UnboundedReceiverStream;
-use tracing::{instrument, Level, span};
+use tracing::{instrument, Level, span, Span};
 
 use text_generation_router::{FinishReason, Token};
 use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
@@ -25,42 +25,23 @@ use text_generation_router::validation::ValidationError::UnsupportedModality;
 use crate::errors::TensorRtLlmBackendError;
 use crate::ffi::{create_tensorrt_llm_backend, TensorRtLlmBackendImpl};
 
-// macro_rules! propagate {
-//     ($ctx: expr, $res: expr) => {
-//         $ctx.sender
-//             .send($res)
-//             .expect("Failed to propagate error back to the transport layer")
-//     };
-// }
-
 type InferResult<T> = Result<T, InferError>;
 
-/// Holds the user provided input to be executed along with a channel allowing
-/// to bubble up all the generated tokens for that tokens the to end stream.
-// pub struct InferenceContext {
-//     /// User provided request
-//     request: ValidGenerateRequest,
-//
-//     /// Inter-process communication handler moving token from the executor thread to the HTTP server
-//     sender: UnboundedSender<InferResult<InferStreamResponse>>,
-//
-//     /// Pin the instant this inference context was submitted
-//     when: Instant,
-//
-//     /// Span that will live as long as entry
-//     span: Span,
-// }
-
 pub(crate) struct Generation {
     executor: Arc<RwLock<UniquePtr<TensorRtLlmBackendImpl>>>,
     done: Arc<AtomicBool>,
 }
 
+/// Holds the user provided input to be executed along with a channel allowing
+/// to bubble up all the generated tokens for that request back to the end stream.
 #[derive(Clone)]
 pub struct GenerationContext {
     sender: UnboundedSender<InferResult<InferStreamResponse>>,
     tokenizer: Arc<Tokenizer>,
+    tokens: Vec<u32>,
     done: Arc<AtomicBool>,
+    start: Instant,
+    span: Span,
 }
 
 impl Stream for Generation {
@@ -175,7 +156,10 @@ impl TensorRtLlmBackend {
             let ctx = Box::new(GenerationContext {
                 sender: sender.clone(),
                 tokenizer: tokenizer.clone(),
+                tokens: vec![],
                 done: Arc::clone(&generation.done),
+                start: Instant::now(),
+                span: Span::current(),
             });
 
             // We are leaking the context on-purpose to avoid the box being dropped while there are
@@ -209,45 +193,50 @@ impl TensorRtLlmBackend {
                 request_id,
                 ctx_,
                 |ctx: *mut GenerationContext,
-                 token: u32,
+                 token_id: u32,
                  logprob: f32,
                  is_final: bool| {
-                    // let text = ctx
-                    //     .tokenizer
-                    //     .decode(&[token], true)
-                    //     .expect("Failed to decode token");
-                    info!("Decoded token: {}", token);
+                    let inner_ctx = &mut *ctx;
+                    inner_ctx.tokens.push(token_id);
+
+                    let text = inner_ctx
+                        .tokenizer
+                        .decode(&[token_id], true)
+                        .expect("Failed to decode token");
+
+                    let token = Token {
+                        id: token_id,
+                        text,
+                        logprob,
+                        special: false,
+                    };
+
                     let out = if is_final {
-                        (*ctx).done.store(true, Ordering::Relaxed);
+                        inner_ctx.done.store(true, Ordering::Relaxed);
+                        let generated_text = inner_ctx
+                            .tokenizer
+                            .decode(&inner_ctx.tokens, true)
+                            .expect("Failed to decode generated_tokens");
+
                         InferStreamResponse::End {
-                            token: Token {
-                                id: token,
-                                text: "".into(),
-                                logprob,
-                                special: false,
-                            },
+                            token,
                             top_tokens: vec![],
                             generated_text: GeneratedText {
-                                text: "".into(),
-                                generated_tokens: u32::MAX,
+                                text: generated_text,
+                                generated_tokens: inner_ctx.tokens.len() as u32,
                                 finish_reason: FinishReason::EndOfSequenceToken,
                                 seed: None,
                             },
-                            start: Instant::now(),
+                            start: inner_ctx.start,
                             queued: Instant::now(),
                         }
                     } else {
                        InferStreamResponse::Intermediate {
-                            token: Token {
-                                id: token,
-                                text: "".into(),
-                                logprob,
-                                special: false,
-                            },
+                            token,
                             top_tokens: vec![],
                         }
                     };
-                    (*ctx)
+                    inner_ctx
                        .sender
                        .send(Ok(out))
                        .expect("Failed to send back generated token");
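For reference, a rough, std-only sketch of the per-token callback flow this patch introduces: accumulate token ids, decode the single token for intermediate events, and decode the full sequence once the final token arrives. `StreamResponse`, `decode`, and `on_token` are illustrative stand-ins for `InferStreamResponse`, `tokenizers::Tokenizer::decode`, and the FFI callback; the real backend also uses tokio unbounded channels and a raw context pointer, which are deliberately simplified here.

```rust
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::mpsc::{channel, Sender};
use std::sync::Arc;
use std::time::Instant;

// Simplified per-token event; stand-in for InferStreamResponse.
struct Token {
    id: u32,
    text: String,
    logprob: f32,
    special: bool,
}

enum StreamResponse {
    Intermediate { token: Token },
    End { token: Token, generated_text: String, generated_tokens: u32, start: Instant },
}

struct GenerationContext {
    sender: Sender<StreamResponse>,
    tokens: Vec<u32>, // accumulated ids, decoded in one pass when the request finishes
    done: Arc<AtomicBool>,
    start: Instant,
}

// Stand-in for `tokenizer.decode(ids, true)`.
fn decode(ids: &[u32]) -> String {
    ids.iter().map(|id| format!("<{id}>")).collect()
}

// Mirrors the callback shape: record the id, emit an intermediate event per
// token, and on the final token flip `done` and emit the full decoded text.
fn on_token(ctx: &mut GenerationContext, token_id: u32, logprob: f32, is_final: bool) {
    ctx.tokens.push(token_id);
    let token = Token { id: token_id, text: decode(&[token_id]), logprob, special: false };

    let out = if is_final {
        ctx.done.store(true, Ordering::Relaxed);
        StreamResponse::End {
            token,
            generated_text: decode(&ctx.tokens),
            generated_tokens: ctx.tokens.len() as u32,
            start: ctx.start,
        }
    } else {
        StreamResponse::Intermediate { token }
    };
    ctx.sender.send(out).expect("Failed to send back generated token");
}

fn main() {
    let (sender, receiver) = channel();
    let mut ctx = GenerationContext {
        sender,
        tokens: vec![],
        done: Arc::new(AtomicBool::new(false)),
        start: Instant::now(),
    };

    // Pretend the executor produced three tokens, the last one final.
    for (i, id) in [5u32, 7, 9].into_iter().enumerate() {
        on_token(&mut ctx, id, -0.1, i == 2);
    }

    while let Ok(msg) = receiver.try_recv() {
        match msg {
            StreamResponse::Intermediate { token } => println!(
                "intermediate: id={} text={} logprob={} special={}",
                token.id, token.text, token.logprob, token.special
            ),
            StreamResponse::End { token, generated_text, generated_tokens, start } => println!(
                "end: last={} full={generated_text} ({generated_tokens} tokens in {:?})",
                token.text,
                start.elapsed()
            ),
        }
    }
}
```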