expose information about potential error happening while decoding

This commit is contained in:
Morgan Funtowicz 2024-07-18 22:07:59 +00:00
parent a19d318947
commit e82dc30e8a
3 changed files with 67 additions and 53 deletions

View File

@ -197,24 +197,28 @@ impl TensorRtLlmBackend {
.await; .await;
while let Some(_) = generation.next().await { while let Some(_) = generation.next().await {
span!(Level::DEBUG, "decode", request_id = request_id)
.in_scope(|| async {
let mut executor_w = executor.write().await; let mut executor_w = executor.write().await;
let executor = executor_w.pin_mut();
unsafe {
debug!("Acquired write lock stream"); debug!("Acquired write lock stream");
executor_w.pin_mut().stream_tokens( span!(Level::DEBUG, "decode")
.in_scope(|| async {
unsafe {
executor.stream_tokens(
request_id, request_id,
ctx_, ctx_,
|ctx: *mut GenerationContext, step: GenerationStep| { |ctx: *mut GenerationContext, step: GenerationStep| {
let inner_ctx = &mut *ctx; let inner_ctx = &mut *ctx;
// Insert the latest generated token to the tracker
inner_ctx.tokens.push(step.token_id);
// Update the timestamp at which the request started effectively // Update the timestamp at which the request started effectively
// Can be a bit off, would need to be before the callback, let's see // Can be a bit off, would need to be before the callback, let's see
inner_ctx.start.get_or_insert(Instant::now()); inner_ctx.start.get_or_insert(Instant::now());
inner_ctx.done.store(step.is_final, Ordering::Relaxed);
// Ensure we are not running into errors
let parcel = if !step.has_error {
// Insert the latest generated token to the tracker
inner_ctx.tokens.push(step.token_id);
// Decode the token // Decode the token
let text = inner_ctx let text = inner_ctx
@ -235,14 +239,13 @@ impl TensorRtLlmBackend {
special, special,
}; };
let out = if step.is_final { if step.is_final {
inner_ctx.done.store(true, Ordering::Relaxed);
let generated_text = inner_ctx let generated_text = inner_ctx
.tokenizer .tokenizer
.decode(&inner_ctx.tokens, true) .decode(&inner_ctx.tokens, true)
.expect("Failed to decode generated_tokens"); .expect("Failed to decode generated_tokens");
InferStreamResponse::End { Ok(InferStreamResponse::End {
token, token,
top_tokens: vec![], top_tokens: vec![],
generated_text: GeneratedText { generated_text: GeneratedText {
@ -253,21 +256,26 @@ impl TensorRtLlmBackend {
}, },
start: inner_ctx.start.unwrap_or(Instant::now()), start: inner_ctx.start.unwrap_or(Instant::now()),
queued: inner_ctx.queued, queued: inner_ctx.queued,
} })
} else { } else {
InferStreamResponse::Intermediate { Ok(InferStreamResponse::Intermediate {
token, token,
top_tokens: vec![], top_tokens: vec![],
})
} }
} else {
Err(InferError::GenerationError(step.error_msg))
}; };
// Send the parcel to the client
inner_ctx inner_ctx
.sender .sender
.send(Ok(out)) .send(parcel)
.expect("Failed to send back generated token"); .expect("Failed to sent msg through the channel");
}, },
); );
debug!("Releasing write lock stream")
} }
debug!("Releasing write lock stream");
}) })
.await; .await;
} }

View File

@ -54,12 +54,17 @@ size_t huggingface::tgi::backends::TensorRtLlmBackendImpl::StreamTokens(
++numTokens; ++numTokens;
SPDLOG_DEBUG(FMT_STRING("\tStreamTokens -> {:d} {:.2f} (final = {})"), token, logProb, isFinal); SPDLOG_DEBUG(FMT_STRING("\tStreamTokens -> {:d} {:.2f} (final = {})"), token, logProb, isFinal);
step = huggingface::tgi::backends::GenerationStep{static_cast<uint32_t>(token), logProb, isFinal}; step = huggingface::tgi::backends::GenerationStep{
static_cast<uint32_t>(token), logProb, isFinal, false, std::move(std::string())
};
SPDLOG_DEBUG("\tStreamTokens -> Post callback"); SPDLOG_DEBUG("\tStreamTokens -> Post callback");
} else { } else {
// TODO : Return rest::Result with error // TODO : Return rest::Result with error
SPDLOG_WARN("\tStreamTokens -> Got error while decoding: {}", item.getErrorMsg()); const auto what = item.getErrorMsg();
step = huggingface::tgi::backends::GenerationStep{std::numeric_limits<uint32_t>::max(), 0.0, true}; SPDLOG_WARN("\tStreamTokens -> Got error while decoding: {}", what);
step = huggingface::tgi::backends::GenerationStep{
std::numeric_limits<uint32_t>::max(), 0.0, true, true, std::move(what)
};
} }
callback(std::move(ctx), std::move(step)); callback(std::move(ctx), std::move(step));

View File

@ -8,11 +8,12 @@ mod ffi {
/// Struct used as shared type between rust and C++ to represent the result /// Struct used as shared type between rust and C++ to represent the result
/// of a single decoding iteration /// of a single decoding iteration
#[derive(Copy, Clone)]
pub struct GenerationStep { pub struct GenerationStep {
token_id: u32, token_id: u32,
log_prob: f32, log_prob: f32,
is_final: bool, is_final: bool,
has_error: bool,
error_msg: String,
} }
extern "Rust" { extern "Rust" {