From 0858af206f61527be04ce38c7680a456f2885b71 Mon Sep 17 00:00:00 2001 From: Tzu-Yu Lee Date: Sun, 18 May 2025 02:22:53 +0800 Subject: [PATCH] fix(trtllm): fix segfault when canceling request When a request is cancelled, the `tensorrt_llm::executor::Result` contains `outputTokenIds` with size 1, but `outputTokenIds[0]` has size 0. This causes `as_generation_step` to segfault. Check the size of `outputTokenIds` and `logProbs` before attempting to access the inner vector. The `finishReasons` can be skipped because it has only one dimension and the minimum beam size is 1. Because cxx has not added Option support yet, include two boolean flags to denote whether each value is valid. Change log level when request is cancelled to debug. --- backends/trtllm/csrc/ffi.hpp | 19 +++++++++++++--- backends/trtllm/src/lib.rs | 2 ++ backends/trtllm/src/looper.rs | 41 ++++++++++++++++++++++++++--------- 3 files changed, 49 insertions(+), 13 deletions(-) diff --git a/backends/trtllm/csrc/ffi.hpp b/backends/trtllm/csrc/ffi.hpp index a877df5a..90d1b9d1 100644 --- a/backends/trtllm/csrc/ffi.hpp +++ b/backends/trtllm/csrc/ffi.hpp @@ -55,13 +55,24 @@ namespace huggingface::tgi::backends::trtllm { const auto reqId = r.getRequestId(); if (!r.hasError()) [[likely]] { const auto result = r.getResult(); - const auto logits = result.logProbs.value()[0]; + std::optional token_id = std::nullopt; + if (!result.outputTokenIds.empty() && !result.outputTokenIds[0].empty()) { + token_id = static_cast(result.outputTokenIds[0][0]); + } + + std::optional log_prob = std::nullopt; + if (result.logProbs && !result.logProbs->empty() && !result.logProbs.value()[0].empty()) { + log_prob = result.logProbs.value()[0].back(); + } + return generation_step_t{ reqId, - static_cast(result.outputTokenIds[0][0]), - logits.back(), + token_id.value_or(0), + log_prob.value_or(0.0), result.isFinal, as_finish_reason_t(result.finishReasons[0]), + token_id.has_value(), + log_prob.has_value(), false, std::string() }; @@ -72,6 
+83,8 @@ namespace huggingface::tgi::backends::trtllm { 0.0, true, finish_reason_t::kNOT_FINISHED, + false, + false, true, std::move(r.getErrorMsg()) }; diff --git a/backends/trtllm/src/lib.rs b/backends/trtllm/src/lib.rs index 52e48f91..b2a9274d 100644 --- a/backends/trtllm/src/lib.rs +++ b/backends/trtllm/src/lib.rs @@ -44,6 +44,8 @@ mod ffi { log_prob: f32, is_final: bool, finish_reason: FinishReason, + token_id_valid: bool, + log_prob_valid: bool, has_error: bool, error_msg: String, } diff --git a/backends/trtllm/src/looper.rs b/backends/trtllm/src/looper.rs index 5fed954f..fd0bc967 100644 --- a/backends/trtllm/src/looper.rs +++ b/backends/trtllm/src/looper.rs @@ -49,16 +49,28 @@ impl<'step> TryFrom<&'step GenerationStep> for DecodedToken { type Error = InferError; fn try_from(step: &'step GenerationStep) -> Result { - if !step.has_error { - Ok(Self { - id: step.token_id, - log_prob: step.log_prob, - is_final: step.is_final, - finish_reason: step.finish_reason, - }) - } else { - Err(GenerationError(step.error_msg.clone())) + if step.has_error { + return Err(GenerationError(step.error_msg.clone())); } + + if !step.token_id_valid { + return Err(GenerationError( + "GenerationStep contains no token_id".to_string(), + )); + } + + if !step.log_prob_valid { + return Err(GenerationError( + "GenerationStep contains no log_prob".to_string(), + )); + } + + Ok(Self { + id: step.token_id, + log_prob: step.log_prob, + is_final: step.is_final, + finish_reason: step.finish_reason, + }) } } @@ -151,7 +163,16 @@ fn executor_status_looper( let _ = in_flights.remove(&step.request_id); } } else { - warn!("Untracked request {}", step.request_id,); + match step.finish_reason { + FinishReason::Cancelled => { + // The client has canceled the request, so this should not generate a + // warning. + debug!("Cancelled request {}", step.request_id); + } + _ => { + warn!("Untracked request {}", step.request_id); + } + } } } }