mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-07-31 04:10:16 +00:00
fix(trtllm): fix segfault when canceling request
When a request is cancelled, the `tensorrt_llm::executor::Result` contains `outputTokenIds` with size 1, but `outputTokenIds[0]` has size 0. This causes `as_generation_step` to segfault. Check the size of `outputTokenIds` and `logProbs` before attempting to access the inner vector. The `finishReasons` can be skipped because it has only one dimension and the minimum beam size is 1. Because cxx has not added Option support yet, include two boolean flags to denote whether each value is valid. Change the log level to debug when a request is cancelled.
This commit is contained in:
parent
cc4b5848b9
commit
0858af206f
@ -55,13 +55,24 @@ namespace huggingface::tgi::backends::trtllm {
|
||||
const auto reqId = r.getRequestId();
|
||||
if (!r.hasError()) [[likely]] {
|
||||
const auto result = r.getResult();
|
||||
const auto logits = result.logProbs.value()[0];
|
||||
std::optional<uint32_t> token_id = std::nullopt;
|
||||
if (!result.outputTokenIds.empty() && !result.outputTokenIds[0].empty()) {
|
||||
token_id = static_cast<uint32_t>(result.outputTokenIds[0][0]);
|
||||
}
|
||||
|
||||
std::optional<float> log_prob = std::nullopt;
|
||||
if (result.logProbs && !result.logProbs->empty() && !result.logProbs.value()[0].empty()) {
|
||||
log_prob = result.logProbs.value()[0].back();
|
||||
}
|
||||
|
||||
return generation_step_t{
|
||||
reqId,
|
||||
static_cast<uint32_t>(result.outputTokenIds[0][0]),
|
||||
logits.back(),
|
||||
token_id.value_or(0),
|
||||
log_prob.value_or(0.0),
|
||||
result.isFinal,
|
||||
as_finish_reason_t(result.finishReasons[0]),
|
||||
token_id.has_value(),
|
||||
log_prob.has_value(),
|
||||
false,
|
||||
std::string()
|
||||
};
|
||||
@ -72,6 +83,8 @@ namespace huggingface::tgi::backends::trtllm {
|
||||
0.0,
|
||||
true,
|
||||
finish_reason_t::kNOT_FINISHED,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
std::move(r.getErrorMsg())
|
||||
};
|
||||
|
@ -44,6 +44,8 @@ mod ffi {
|
||||
log_prob: f32,
|
||||
is_final: bool,
|
||||
finish_reason: FinishReason,
|
||||
token_id_valid: bool,
|
||||
log_prob_valid: bool,
|
||||
has_error: bool,
|
||||
error_msg: String,
|
||||
}
|
||||
|
@ -49,16 +49,28 @@ impl<'step> TryFrom<&'step GenerationStep> for DecodedToken {
|
||||
type Error = InferError;
|
||||
|
||||
fn try_from(step: &'step GenerationStep) -> Result<Self, Self::Error> {
|
||||
if !step.has_error {
|
||||
Ok(Self {
|
||||
id: step.token_id,
|
||||
log_prob: step.log_prob,
|
||||
is_final: step.is_final,
|
||||
finish_reason: step.finish_reason,
|
||||
})
|
||||
} else {
|
||||
Err(GenerationError(step.error_msg.clone()))
|
||||
if step.has_error {
|
||||
return Err(GenerationError(step.error_msg.clone()));
|
||||
}
|
||||
|
||||
if !step.token_id_valid {
|
||||
return Err(GenerationError(
|
||||
"GenerationStep contains no token_id".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
if !step.log_prob_valid {
|
||||
return Err(GenerationError(
|
||||
"GenerationStep contains no log_prob".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
id: step.token_id,
|
||||
log_prob: step.log_prob,
|
||||
is_final: step.is_final,
|
||||
finish_reason: step.finish_reason,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@ -151,7 +163,16 @@ fn executor_status_looper(
|
||||
let _ = in_flights.remove(&step.request_id);
|
||||
}
|
||||
} else {
|
||||
warn!("Untracked request {}", step.request_id,);
|
||||
match step.finish_reason {
|
||||
FinishReason::Cancelled => {
|
||||
// The client has canceled the request, so this should not generate a
|
||||
// warning.
|
||||
debug!("Cancelled request {}", step.request_id);
|
||||
}
|
||||
_ => {
|
||||
warn!("Untracked request {}", step.request_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user