diff --git a/backends/trtllm/include/ffi.h b/backends/trtllm/include/ffi.h
index f4a998b2..449bcd4d 100644
--- a/backends/trtllm/include/ffi.h
+++ b/backends/trtllm/include/ffi.h
@@ -64,8 +64,6 @@ namespace huggingface::tgi::backends {
         std::unique_ptr<std::vector<GenerationStep>> PullTokens();
     };
 
-    GenerationStep ConvertResponseToGenerationStep(const tle::Response &response);
-
     /***
      *
      * @param engineFolder
diff --git a/backends/trtllm/src/ffi.cpp b/backends/trtllm/src/ffi.cpp
index b9f3d009..b15a4c40 100644
--- a/backends/trtllm/src/ffi.cpp
+++ b/backends/trtllm/src/ffi.cpp
@@ -36,34 +36,38 @@ uint64_t huggingface::tgi::backends::TensorRtLlmBackendImpl::Submit(
 
 std::unique_ptr<std::vector<huggingface::tgi::backends::GenerationStep>>
 huggingface::tgi::backends::TensorRtLlmBackendImpl::PullTokens() {
     const auto responses = TensorRtLlmBackend::PullNewTokens();
-    auto steps = std::make_unique<std::vector<GenerationStep>>(responses.size());
-    std::ranges::copy(std::views::transform(responses, ConvertResponseToGenerationStep), std::back_inserter(*steps));
-    return steps;
-}
-huggingface::tgi::backends::GenerationStep
-huggingface::tgi::backends::ConvertResponseToGenerationStep(const tle::Response &response) {
-    const auto reqId = response.getRequestId();
-    if (!response.hasError()) {
-        const auto result = response.getResult();
-        return std::move(GenerationStep{
-                reqId,
-                static_cast<uint32_t>(result.outputTokenIds[0][0]),
-                result.logProbs.value()[0][0],
-                result.isFinal,
-                false,
-                std::string()
-        });
-    } else {
-        return std::move(GenerationStep{
-                reqId,
-                0,
-                0.0,
-                true,
-                true,
-                std::move(response.getErrorMsg())
-        });
-    }
+    auto steps = std::make_unique<std::vector<GenerationStep>>();
+    steps->reserve(responses.size());
+
+    SPDLOG_DEBUG(FMT_STRING("Pulled out {:d} new tokens"), responses.size());
+
+    // Transform tle::Response to GenerationStep
+    std::ranges::transform(responses.begin(), responses.end(), std::back_inserter(*steps), [&](const Response &r) {
+        const auto reqId = r.getRequestId();
+        if (!r.hasError()) {
+            const auto result = r.getResult();
+            return GenerationStep{
+                    reqId,
+                    static_cast<uint32_t>(result.outputTokenIds[0][0]),
+                    result.logProbs.value()[0][0],
+                    result.isFinal,
+                    false,
+                    std::string()
+            };
+        } else {
+            return GenerationStep{
+                    reqId,
+                    0,
+                    0.0,
+                    true,
+                    true,
+                    std::move(r.getErrorMsg())
+            };
+        }
+    });
+
+    return steps;
 }
 
 std::unique_ptr<huggingface::tgi::backends::TensorRtLlmBackendImpl>
diff --git a/backends/trtllm/src/looper.rs b/backends/trtllm/src/looper.rs
index c287fa55..f070bad6 100644
--- a/backends/trtllm/src/looper.rs
+++ b/backends/trtllm/src/looper.rs
@@ -8,18 +8,20 @@ use cxx::UniquePtr;
 use hashbrown::HashMap;
 use log::warn;
 use tokenizers::{Encoding, Tokenizer};
-use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
 use tokio::sync::mpsc::error::SendError;
-use tokio::task::{JoinHandle, spawn_blocking};
+use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
+use tokio::task::{spawn_blocking, JoinHandle};
 use tokio::time::Instant;
 use tokio_stream::wrappers::UnboundedReceiverStream;
-use tracing::{error, info, Level, span};
+use tracing::{debug, error, info, span, Level};
 
-use text_generation_router::{FinishReason, Token};
+use text_generation_router::infer::InferError::{GenerationError, ValidationError};
 use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
-use text_generation_router::infer::InferError::GenerationError;
-use text_generation_router::validation::{Chunk, ValidationError, ValidGenerateRequest};
-use text_generation_router::validation::ValidationError::UnsupportedModality;
+use text_generation_router::validation::ValidationError::{
+    EmptyInput, Grammar, TopNTokensDisabled, UnsupportedModality,
+};
+use text_generation_router::validation::{Chunk, ValidGenerateRequest};
+use text_generation_router::{FinishReason, Token};
 
 use crate::errors::TensorRtLlmBackendError;
 use crate::ffi::{create_tensorrt_llm_backend, GenerationStep, TensorRtLlmBackendImpl};
@@ -95,6 +97,8 @@ fn executor_status_poller(
         if backend.num_responses_ready() > 0 {
             match backend.pin_mut().pull_tokens() {
                 Ok(responses) => {
+                    debug!("Received {} tokens from the executor", responses.len());
+
                     // worse case scenario is one token for each response: with_capacity(responses.len())
                     // grouper will group decoded tokens per request to decode multiple tokens
                     let mut grouper: HashMap<u64, DecodedTokenContext> =
@@ -102,33 +106,49 @@ fn executor_status_poller(
                     // Iterate through all the decoded token
                     for step in responses.deref() {
-                        let request_id = step.request_id;
-
-                        match in_flights.get(&request_id) {
+                        match in_flights.get(&step.request_id) {
                             Some(ctx) => {
-                                info!("New token for {} -> {}", request_id, step.token_id);
+                                debug!(
+                                    "{} -> (token={}, final={})",
+                                    step.request_id, step.token_id, step.is_final
+                                );
+
+                                // If no error, let's forward to post-processor
                                 if !step.has_error {
-                                    let req_group = grouper.entry(request_id).or_insert(
+                                    let req_group = grouper.entry(step.request_id).or_insert(
                                         DecodedTokenContext {
                                             tokens: vec![],
                                             ctx: ctx.streamer.clone(), // Arc::clone() = cheap
                                         },
                                     );
                                     req_group.tokens.push(step.clone()); // Should be ultra cheap
-
-                                    if step.is_final {
-                                        let _ = in_flights.remove(&step.request_id);
-                                    }
                                 } else {
                                     warn!(
                                         "Error for request: {} -> {}",
-                                        request_id, &step.error_msg
+                                        step.request_id, &step.error_msg
                                     );
+
+                                    // TODO: Send something back to the postprocessor for the client?
+                                }
+
+                                // Remove from tracked requests
+                                if step.is_final {
+                                    let _ = in_flights.remove(&step.request_id);
                                 }
                             }
                             None => {
-                                error!("Got step for untracked request {}", request_id);
+                                if step.has_error {
+                                    error!(
+                                        "Untracked request {} -> {}",
+                                        step.request_id, &step.error_msg
+                                    );
+                                    continue;
+                                } else {
+                                    error!(
+                                        "Got step for untracked request {}",
+                                        step.request_id
+                                    );
+                                }
                             }
                         }
                     }
@@ -275,18 +295,16 @@ impl TensorRtLlmBackendV2 {
 
     fn validate(request: &ValidGenerateRequest) -> InferResult<&String> {
         if request.top_n_tokens > 1 {
-            return Err(InferError::ValidationError(
-                ValidationError::TopNTokensDisabled,
-            ));
+            return Err(InferError::ValidationError(TopNTokensDisabled));
         }
 
         // TODO: Is it really needed? How can it be validated before?
        if request.parameters.grammar.is_some() {
-            return Err(InferError::ValidationError(ValidationError::Grammar));
+            return Err(InferError::ValidationError(Grammar));
        }
 
         match request.inputs.len() {
-            0 => Err(InferError::ValidationError(ValidationError::EmptyInput)),
+            0 => Err(InferError::ValidationError(EmptyInput)),
             2.. => Err(InferError::GenerationError(
                 "TensorRT-LLM backend don't support multi-chunk".into(),
             )),
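
Side note on the PullTokens() hunk: the patch drops the ConvertResponseToGenerationStep free function and instead reserves the output vector once and maps each response in place with std::ranges::transform into a std::back_inserter. The snippet below is a minimal, self-contained C++20 sketch of that reserve + transform pattern only; Response and Step here are illustrative stand-ins, not the tle::Response or GenerationStep types from the backend, and the field layout is invented for the example.

// transform_sketch.cpp -- build with: g++ -std=c++20 transform_sketch.cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <iterator>
#include <string>
#include <vector>

// Stand-in for an executor response (hypothetical fields, for illustration only).
struct Response {
    uint64_t request_id;
    bool has_error;
    uint32_t token_id;
    std::string error_msg;
};

// Stand-in for the flattened step handed over the FFI boundary.
struct Step {
    uint64_t request_id;
    uint32_t token_id;
    bool is_error;
    std::string error_msg;
};

int main() {
    const std::vector<Response> responses{
            {1, false, 42, {}},
            {2, true, 0, "executor failure"},
    };

    // Same shape as the patch: reserve once, then transform through a
    // back_inserter so both successful and failed responses map to one
    // uniform step type without intermediate allocations.
    std::vector<Step> steps;
    steps.reserve(responses.size());
    std::ranges::transform(responses, std::back_inserter(steps), [](const Response &r) {
        if (!r.has_error) {
            return Step{r.request_id, r.token_id, false, std::string()};
        }
        return Step{r.request_id, 0, true, r.error_msg};
    });

    for (const auto &s : steps) {
        std::cout << s.request_id << " -> "
                  << (s.is_error ? s.error_msg : std::to_string(s.token_id)) << '\n';
    }
}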