fix(trtllm): fix segfault when canceling request

When a request is cancelled, the `tensorrt_llm::executor::Result`
contains `outputTokenIds` with size 1, but `outputTokenIds[0]` has size
0. This causes `as_generation_step` to segfault.

Check the size of `outputTokenIds` and `logProbs` before attempting to
access the inner vector. The `finishReasons` can be skipped because it
has only one dimension and the minimum beam size is 1.
Because cxx have not added Option support yet, include two boolean flags
to denote whether the value is valid.

Change log level when request is cancelled to debug.
This commit is contained in:
Tzu-Yu Lee 2025-05-18 02:22:53 +08:00
parent cc4b5848b9
commit 0858af206f
3 changed files with 49 additions and 13 deletions

View File

@ -55,13 +55,24 @@ namespace huggingface::tgi::backends::trtllm {
const auto reqId = r.getRequestId();
if (!r.hasError()) [[likely]] {
const auto result = r.getResult();
const auto logits = result.logProbs.value()[0];
std::optional<uint32_t> token_id = std::nullopt;
if (!result.outputTokenIds.empty() && !result.outputTokenIds[0].empty()) {
token_id = static_cast<uint32_t>(result.outputTokenIds[0][0]);
}
std::optional<float> log_prob = std::nullopt;
if (result.logProbs && !result.logProbs->empty() && !result.logProbs.value()[0].empty()) {
log_prob = result.logProbs.value()[0].back();
}
return generation_step_t{
reqId,
static_cast<uint32_t>(result.outputTokenIds[0][0]),
logits.back(),
token_id.value_or(0),
log_prob.value_or(0.0),
result.isFinal,
as_finish_reason_t(result.finishReasons[0]),
token_id.has_value(),
log_prob.has_value(),
false,
std::string()
};
@ -72,6 +83,8 @@ namespace huggingface::tgi::backends::trtllm {
0.0,
true,
finish_reason_t::kNOT_FINISHED,
false,
false,
true,
std::move(r.getErrorMsg())
};

View File

@ -44,6 +44,8 @@ mod ffi {
log_prob: f32,
is_final: bool,
finish_reason: FinishReason,
token_id_valid: bool,
log_prob_valid: bool,
has_error: bool,
error_msg: String,
}

View File

@ -49,16 +49,28 @@ impl<'step> TryFrom<&'step GenerationStep> for DecodedToken {
type Error = InferError;
fn try_from(step: &'step GenerationStep) -> Result<Self, Self::Error> {
if !step.has_error {
Ok(Self {
id: step.token_id,
log_prob: step.log_prob,
is_final: step.is_final,
finish_reason: step.finish_reason,
})
} else {
Err(GenerationError(step.error_msg.clone()))
if step.has_error {
return Err(GenerationError(step.error_msg.clone()));
}
if !step.token_id_valid {
return Err(GenerationError(
"GenerationStep contains no token_id".to_string(),
));
}
if !step.log_prob_valid {
return Err(GenerationError(
"GenerationStep contains no log_prob".to_string(),
));
}
Ok(Self {
id: step.token_id,
log_prob: step.log_prob,
is_final: step.is_final,
finish_reason: step.finish_reason,
})
}
}
@ -151,7 +163,16 @@ fn executor_status_looper(
let _ = in_flights.remove(&step.request_id);
}
} else {
warn!("Untracked request {}", step.request_id,);
match step.finish_reason {
FinishReason::Cancelled => {
// The client has canceled the request, so this should not generate a
// warning.
debug!("Cancelled request {}", step.request_id);
}
_ => {
warn!("Untracked request {}", step.request_id);
}
}
}
}
}