diff --git a/backends/trtllm/include/ffi.h b/backends/trtllm/include/ffi.h
index 63bb4f19..fe0be9fc 100644
--- a/backends/trtllm/include/ffi.h
+++ b/backends/trtllm/include/ffi.h
@@ -60,7 +60,8 @@ namespace huggingface::tgi::backends {
         size_t StreamTokens(
                 const RequestId requestId,
                 huggingface::tgi::backends::GenerationContext *ctx,
-                rust::Fn<void(huggingface::tgi::backends::GenerationContext *, uint32_t, float_t, bool)> callback);
+                rust::Fn<void(huggingface::tgi::backends::GenerationContext *,
+                              huggingface::tgi::backends::GenerationStep)> callback);
     };
 
     /***
diff --git a/backends/trtllm/src/backend.rs b/backends/trtllm/src/backend.rs
index 5e7a2098..1c4878e1 100644
--- a/backends/trtllm/src/backend.rs
+++ b/backends/trtllm/src/backend.rs
@@ -20,13 +20,11 @@ use tracing::{instrument, Level, span};
 
 use text_generation_router::{FinishReason, Token};
 use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
-use text_generation_router::validation::{
-    Chunk, ValidationError, ValidGenerateRequest, ValidParameters,
-};
+use text_generation_router::validation::{Chunk, ValidationError, ValidGenerateRequest};
 use text_generation_router::validation::ValidationError::UnsupportedModality;
 
 use crate::errors::TensorRtLlmBackendError;
-use crate::ffi::{create_tensorrt_llm_backend, TensorRtLlmBackendImpl};
+use crate::ffi::{create_tensorrt_llm_backend, GenerationStep, TensorRtLlmBackendImpl};
 
 // Value used to poll the state of the generation stream
 static POLLING_INTERVAL_US: OnceLock = OnceLock::new();
@@ -208,14 +206,11 @@ impl TensorRtLlmBackend {
                 executor_w.pin_mut().stream_tokens(
                     request_id,
                     ctx_,
-                    |ctx: *mut GenerationContext,
-                     token_id: u32,
-                     logprob: f32,
-                     is_final: bool| {
+                    |ctx: *mut GenerationContext, step: GenerationStep| {
                         let inner_ctx = &mut *ctx;
 
                         // Insert the latest generated token to the tracker
-                        inner_ctx.tokens.push(token_id);
+                        inner_ctx.tokens.push(step.token_id);
 
                         // Update the timestamp at which the request started effectively
                         // Can be a bit off, would need to be before the callback, let's see
@@ -224,7 +219,7 @@ impl TensorRtLlmBackend {
                         // Decode the token
                         let text = inner_ctx
                             .tokenizer
-                            .decode(&[token_id], true)
+                            .decode(&[step.token_id], true)
                             .expect("Failed to decode token");
 
                         let special = inner_ctx
@@ -234,13 +229,13 @@ impl TensorRtLlmBackend {
 
                         // Create the structure holding the token
                         let token = Token {
-                            id: token_id,
+                            id: step.token_id,
                             text,
-                            logprob,
+                            logprob: step.log_prob,
                             special,
                         };
 
-                        let out = if is_final {
+                        let out = if step.is_final {
                             inner_ctx.done.store(true, Ordering::Relaxed);
                             let generated_text = inner_ctx
                                 .tokenizer
diff --git a/backends/trtllm/src/ffi.cpp b/backends/trtllm/src/ffi.cpp
index a4433f2d..017d0121 100644
--- a/backends/trtllm/src/ffi.cpp
+++ b/backends/trtllm/src/ffi.cpp
@@ -6,6 +6,7 @@
 #include
 #include
 #include
+#include <limits>
 #include
 #include
 
@@ -36,10 +37,12 @@ uint64_t huggingface::tgi::backends::TensorRtLlmBackendImpl::Submit(
 size_t huggingface::tgi::backends::TensorRtLlmBackendImpl::StreamTokens(
         const uint64_t requestId,
         huggingface::tgi::backends::GenerationContext *ctx,
-        rust::Fn<void(huggingface::tgi::backends::GenerationContext *, uint32_t, float_t, bool)> callback) {
+        rust::Fn<void(huggingface::tgi::backends::GenerationContext *,
+                      huggingface::tgi::backends::GenerationStep)> callback) {
     size_t numTokens = 0;
     for (const auto &item: Poll(requestId)) {
+        GenerationStep step;
         if (!item.hasError()) {
             SPDLOG_DEBUG("\tStreamTokens -> Decoding token...");
             const auto decoded = item.getResult();
@@ -51,13 +54,15 @@ size_t huggingface::tgi::backends::TensorRtLlmBackendImpl::StreamTokens(
             ++numTokens;
 
             SPDLOG_DEBUG(FMT_STRING("\tStreamTokens -> {:d} {:.2f} (final = {})"), token, logProb, isFinal);
-            callback(std::move(ctx), token, logProb, isFinal);
+            step = huggingface::tgi::backends::GenerationStep{static_cast<uint32_t>(token), logProb, isFinal};
 
             SPDLOG_DEBUG("\tStreamTokens -> Post callback");
         } else {
             // TODO : Return rest::Result with error
             SPDLOG_WARN("\tStreamTokens -> Got error while decoding: {}", item.getErrorMsg());
-            callback(std::move(ctx), 0, 0.0, true);
+            step = huggingface::tgi::backends::GenerationStep{std::numeric_limits<uint32_t>::max(), 0.0, true};
         }
+
+        callback(std::move(ctx), std::move(step));
     }
 
     return numTokens;
diff --git a/backends/trtllm/src/lib.rs b/backends/trtllm/src/lib.rs
index d47a4b43..4b1ff751 100644
--- a/backends/trtllm/src/lib.rs
+++ b/backends/trtllm/src/lib.rs
@@ -1,12 +1,20 @@
-pub use backend::TensorRtLlmBackend;
-
-use crate::backend::GenerationContext;
+pub use backend::{GenerationContext, TensorRtLlmBackend};
 
 mod backend;
 pub mod errors;
 
 #[cxx::bridge(namespace = "huggingface::tgi::backends")]
 mod ffi {
+
+    /// Struct used as shared type between rust and C++ to represent the result
+    /// of a single decoding iteration
+    #[derive(Copy, Clone)]
+    pub struct GenerationStep {
+        token_id: u32,
+        log_prob: f32,
+        is_final: bool,
+    }
+
     extern "Rust" {
         type GenerationContext;
     }
@@ -60,7 +68,7 @@
             self: Pin<&mut TensorRtLlmBackendImpl>,
             request_id: u64,
             ctx: *mut GenerationContext,
-            cb: unsafe fn(*mut GenerationContext, u32, f32, bool),
+            cb: unsafe fn(*mut GenerationContext, GenerationStep),
         ) -> usize;
 
         // #[rust_name = "shutdown"]
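For context, a minimal, self-contained sketch of the cxx shared-struct pattern this patch relies on (not part of the patch itself): a struct declared inside the `#[cxx::bridge]` module is also emitted into the generated C++ header, so the C++ side can brace-initialize a `GenerationStep` and pass it by value to the Rust callback, which reads the fields directly. The `example` namespace, `on_step`, and the values used below are illustrative assumptions, not code from this change.

```rust
// Sketch of a cxx shared struct mirroring GenerationStep.
// Assumes the `cxx` crate is a dependency; only the struct shape is taken
// from the patch, everything else here is illustrative.
#[cxx::bridge(namespace = "example")]
mod ffi {
    /// Shared type visible to both Rust and the generated C++ header.
    #[derive(Copy, Clone)]
    pub struct GenerationStep {
        token_id: u32,
        log_prob: f32,
        is_final: bool,
    }
}

/// Callback-style consumer: the struct crosses the FFI boundary by value,
/// so fields are read directly with no manual (de)serialization.
fn on_step(step: ffi::GenerationStep) {
    println!(
        "token={} logprob={} final={}",
        step.token_id, step.log_prob, step.is_final
    );
}

fn main() {
    // Stand-in for the C++ side constructing a step and invoking the callback.
    on_step(ffi::GenerationStep { token_id: 42, log_prob: -0.17, is_final: true });
}
```

Bundling `token_id`, `log_prob`, and `is_final` into one shared struct keeps the `rust::Fn` callback signature to a single value parameter, so future fields only touch the struct definition rather than every call site on both sides of the bridge.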