diff --git a/backends/trtllm/csrc/ffi.hpp b/backends/trtllm/csrc/ffi.hpp
index a877df5a..90d1b9d1 100644
--- a/backends/trtllm/csrc/ffi.hpp
+++ b/backends/trtllm/csrc/ffi.hpp
@@ -55,13 +55,24 @@ namespace huggingface::tgi::backends::trtllm {
         const auto reqId = r.getRequestId();
         if (!r.hasError()) [[likely]] {
             const auto result = r.getResult();
-            const auto logits = result.logProbs.value()[0];
+            std::optional<uint32_t> token_id = std::nullopt;
+            if (!result.outputTokenIds.empty() && !result.outputTokenIds[0].empty()) {
+                token_id = static_cast<uint32_t>(result.outputTokenIds[0][0]);
+            }
+
+            std::optional<float> log_prob = std::nullopt;
+            if (result.logProbs && !result.logProbs->empty() && !result.logProbs.value()[0].empty()) {
+                log_prob = result.logProbs.value()[0].back();
+            }
+
             return generation_step_t{
                     reqId,
-                    static_cast<uint32_t>(result.outputTokenIds[0][0]),
-                    logits.back(),
+                    token_id.value_or(0),
+                    log_prob.value_or(0.0),
                     result.isFinal,
                     as_finish_reason_t(result.finishReasons[0]),
+                    token_id.has_value(),
+                    log_prob.has_value(),
                     false,
                     std::string()
             };
@@ -72,6 +83,8 @@ namespace huggingface::tgi::backends::trtllm {
                     0.0,
                     true,
                     finish_reason_t::kNOT_FINISHED,
+                    false,
+                    false,
                     true,
                     std::move(r.getErrorMsg())
             };
diff --git a/backends/trtllm/src/lib.rs b/backends/trtllm/src/lib.rs
index 52e48f91..b2a9274d 100644
--- a/backends/trtllm/src/lib.rs
+++ b/backends/trtllm/src/lib.rs
@@ -44,6 +44,8 @@ mod ffi {
         log_prob: f32,
         is_final: bool,
         finish_reason: FinishReason,
+        token_id_valid: bool,
+        log_prob_valid: bool,
         has_error: bool,
         error_msg: String,
     }
diff --git a/backends/trtllm/src/looper.rs b/backends/trtllm/src/looper.rs
index 5fed954f..fd0bc967 100644
--- a/backends/trtllm/src/looper.rs
+++ b/backends/trtllm/src/looper.rs
@@ -49,16 +49,28 @@ impl<'step> TryFrom<&'step GenerationStep> for DecodedToken {
     type Error = InferError;
 
     fn try_from(step: &'step GenerationStep) -> Result<Self, Self::Error> {
-        if !step.has_error {
-            Ok(Self {
-                id: step.token_id,
-                log_prob: step.log_prob,
-                is_final: step.is_final,
-                finish_reason: step.finish_reason,
-            })
-        } else {
-            Err(GenerationError(step.error_msg.clone()))
+        if step.has_error {
+            return Err(GenerationError(step.error_msg.clone()));
         }
+
+        if !step.token_id_valid {
+            return Err(GenerationError(
+                "GenerationStep contains no token_id".to_string(),
+            ));
+        }
+
+        if !step.log_prob_valid {
+            return Err(GenerationError(
+                "GenerationStep contains no log_prob".to_string(),
+            ));
+        }
+
+        Ok(Self {
+            id: step.token_id,
+            log_prob: step.log_prob,
+            is_final: step.is_final,
+            finish_reason: step.finish_reason,
+        })
     }
 }
 
@@ -151,7 +163,16 @@ fn executor_status_looper(
                                 let _ = in_flights.remove(&step.request_id);
                             }
                         } else {
-                            warn!("Untracked request {}", step.request_id,);
+                            match step.finish_reason {
+                                FinishReason::Cancelled => {
+                                    // The client has canceled the request, so this should not generate a
+                                    // warning.
+                                    debug!("Cancelled request {}", step.request_id);
+                                }
+                                _ => {
+                                    warn!("Untracked request {}", step.request_id);
+                                }
+                            }
                         }
                     }
                 }