expose information about potential error happening while decoding

Morgan Funtowicz 2024-07-18 22:07:59 +00:00
parent a19d318947
commit e82dc30e8a
3 changed files with 67 additions and 53 deletions


@@ -197,77 +197,85 @@ impl TensorRtLlmBackend {
                 .await;
             while let Some(_) = generation.next().await {
-                span!(Level::DEBUG, "decode", request_id = request_id)
-                    .in_scope(|| async {
-                        let mut executor_w = executor.write().await;
+                let mut executor_w = executor.write().await;
+                let executor = executor_w.pin_mut();
+                debug!("Acquired write lock stream");
+
+                span!(Level::DEBUG, "decode")
+                    .in_scope(|| async {
                         unsafe {
-                            debug!("Acquired write lock stream");
-                            executor_w.pin_mut().stream_tokens(
+                            executor.stream_tokens(
                                 request_id,
                                 ctx_,
                                 |ctx: *mut GenerationContext, step: GenerationStep| {
                                     let inner_ctx = &mut *ctx;
-                                    // Insert the latest generated token to the tracker
-                                    inner_ctx.tokens.push(step.token_id);
+
                                     // Update the timestamp at which the request started effectively
                                     // Can be a bit off, would need to be before the callback, let's see
                                     inner_ctx.start.get_or_insert(Instant::now());
+                                    inner_ctx.done.store(step.is_final, Ordering::Relaxed);
 
-                                    // Decode the token
-                                    let text = inner_ctx
-                                        .tokenizer
-                                        .decode(&[step.token_id], true)
-                                        .expect("Failed to decode token");
-                                    let special = inner_ctx
-                                        .tokenizer
-                                        .get_added_vocabulary()
-                                        .is_special_token(&text);
-
-                                    // Create the structure holding the token
-                                    let token = Token {
-                                        id: step.token_id,
-                                        text,
-                                        logprob: step.log_prob,
-                                        special,
-                                    };
-
-                                    let out = if step.is_final {
-                                        inner_ctx.done.store(true, Ordering::Relaxed);
-                                        let generated_text = inner_ctx
-                                            .tokenizer
-                                            .decode(&inner_ctx.tokens, true)
-                                            .expect("Failed to decode generated_tokens");
-                                        InferStreamResponse::End {
-                                            token,
-                                            top_tokens: vec![],
-                                            generated_text: GeneratedText {
-                                                text: generated_text,
-                                                generated_tokens: inner_ctx.tokens.len() as u32,
-                                                finish_reason: FinishReason::EndOfSequenceToken,
-                                                seed: None,
-                                            },
-                                            start: inner_ctx.start.unwrap_or(Instant::now()),
-                                            queued: inner_ctx.queued,
-                                        }
-                                    } else {
-                                        InferStreamResponse::Intermediate {
-                                            token,
-                                            top_tokens: vec![],
-                                        }
-                                    };
-
-                                    inner_ctx
-                                        .sender
-                                        .send(Ok(out))
-                                        .expect("Failed to send back generated token");
+                                    // Ensure we are not running into errors
+                                    let parcel = if !step.has_error {
+                                        // Insert the latest generated token to the tracker
+                                        inner_ctx.tokens.push(step.token_id);
+
+                                        // Decode the token
+                                        let text = inner_ctx
+                                            .tokenizer
+                                            .decode(&[step.token_id], true)
+                                            .expect("Failed to decode token");
+                                        let special = inner_ctx
+                                            .tokenizer
+                                            .get_added_vocabulary()
+                                            .is_special_token(&text);
+
+                                        // Create the structure holding the token
+                                        let token = Token {
+                                            id: step.token_id,
+                                            text,
+                                            logprob: step.log_prob,
+                                            special,
+                                        };
+
+                                        if step.is_final {
+                                            let generated_text = inner_ctx
+                                                .tokenizer
+                                                .decode(&inner_ctx.tokens, true)
+                                                .expect("Failed to decode generated_tokens");
+                                            Ok(InferStreamResponse::End {
+                                                token,
+                                                top_tokens: vec![],
+                                                generated_text: GeneratedText {
+                                                    text: generated_text,
+                                                    generated_tokens: inner_ctx.tokens.len() as u32,
+                                                    finish_reason: FinishReason::EndOfSequenceToken,
+                                                    seed: None,
+                                                },
+                                                start: inner_ctx.start.unwrap_or(Instant::now()),
+                                                queued: inner_ctx.queued,
+                                            })
+                                        } else {
+                                            Ok(InferStreamResponse::Intermediate {
+                                                token,
+                                                top_tokens: vec![],
+                                            })
+                                        }
+                                    } else {
+                                        Err(InferError::GenerationError(step.error_msg))
+                                    };
+
+                                    // Send the parcel to the client
+                                    inner_ctx
+                                        .sender
+                                        .send(parcel)
+                                        .expect("Failed to sent msg through the channel");
                                 },
                             );
-                            debug!("Releasing write lock stream")
                         }
+                        debug!("Releasing write lock stream");
                     })
                     .await;
             }
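The core of the change is visible in the callback above: instead of always pushing `Ok(out)` into the channel, the callback now builds a `parcel` of type `Result`, so a decoding failure reaches the stream consumer as an explicit error rather than being swallowed. Below is a minimal, self-contained sketch of that pattern; the types `Step`, `StreamResponse`, `StreamError` and the plain `std::sync::mpsc` channel are illustrative stand-ins, not the actual TGI types.

use std::sync::mpsc;

// Simplified stand-ins for the backend types; the real InferStreamResponse,
// InferError and GenerationStep live in the TGI codebase.
#[derive(Debug)]
enum StreamResponse {
    Intermediate { token_id: u32 },
    End { token_ids: Vec<u32> },
}

#[derive(Debug)]
enum StreamError {
    Generation(String),
}

// One decoding step as reported by the executor, mirroring the fields added
// to the shared GenerationStep struct (has_error / error_msg).
struct Step {
    token_id: u32,
    is_final: bool,
    has_error: bool,
    error_msg: String,
}

// Build the "parcel" the same way the callback above does: Ok(...) on success,
// Err(...) carrying the executor's error message otherwise.
fn make_parcel(step: Step, tokens: &mut Vec<u32>) -> Result<StreamResponse, StreamError> {
    if !step.has_error {
        tokens.push(step.token_id);
        if step.is_final {
            Ok(StreamResponse::End { token_ids: tokens.clone() })
        } else {
            Ok(StreamResponse::Intermediate { token_id: step.token_id })
        }
    } else {
        Err(StreamError::Generation(step.error_msg))
    }
}

fn main() {
    let (sender, receiver) = mpsc::channel();
    let mut tokens = Vec::new();

    // A healthy step followed by a failing one.
    let steps = vec![
        Step { token_id: 42, is_final: false, has_error: false, error_msg: String::new() },
        Step { token_id: u32::MAX, is_final: true, has_error: true, error_msg: "decoder failure".into() },
    ];

    for step in steps {
        sender
            .send(make_parcel(step, &mut tokens))
            .expect("Failed to send msg through the channel");
    }
    drop(sender);

    // The consumer (the HTTP streaming layer in TGI) now sees the error as a
    // regular item on the channel instead of a silently dropped request.
    for parcel in receiver {
        match parcel {
            Ok(resp) => println!("client receives: {:?}", resp),
            Err(err) => eprintln!("client receives error: {:?}", err),
        }
    }
}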


@@ -54,12 +54,17 @@ size_t huggingface::tgi::backends::TensorRtLlmBackendImpl::StreamTokens(
             ++numTokens;
             SPDLOG_DEBUG(FMT_STRING("\tStreamTokens -> {:d} {:.2f} (final = {})"), token, logProb, isFinal);
-            step = huggingface::tgi::backends::GenerationStep{static_cast<uint32_t>(token), logProb, isFinal};
+            step = huggingface::tgi::backends::GenerationStep{
+                    static_cast<uint32_t>(token), logProb, isFinal, false, std::move(std::string())
+            };
             SPDLOG_DEBUG("\tStreamTokens -> Post callback");
         } else {
             // TODO : Return rest::Result with error
-            SPDLOG_WARN("\tStreamTokens -> Got error while decoding: {}", item.getErrorMsg());
-            step = huggingface::tgi::backends::GenerationStep{std::numeric_limits<uint32_t>::max(), 0.0, true};
+            const auto what = item.getErrorMsg();
+            SPDLOG_WARN("\tStreamTokens -> Got error while decoding: {}", what);
+            step = huggingface::tgi::backends::GenerationStep{
+                    std::numeric_limits<uint32_t>::max(), 0.0, true, true, std::move(what)
+            };
         }
         callback(std::move(ctx), std::move(step));
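The two brace-initializers above now have to supply values for the extra `has_error` and `error_msg` fields, in the same order as the shared struct declared in the cxx bridge (last file below): a real token with the flag cleared on success, and a sentinel `uint32_t` max with the flag raised and the message forwarded on failure. As a small Rust sketch of the two values being produced, using a local mirror of the struct rather than the generated ffi type:

// Local mirror of the shared struct from the cxx bridge; the real one is
// generated by the #[cxx::bridge] module in the TGI sources.
#[derive(Debug)]
struct GenerationStep {
    token_id: u32,
    log_prob: f32,
    is_final: bool,
    has_error: bool,
    error_msg: String,
}

fn main() {
    // Success path: a real token, no error, empty message
    // (matches `{token, logProb, isFinal, false, std::string()}` above).
    let ok_step = GenerationStep {
        token_id: 1337,
        log_prob: -0.42,
        is_final: false,
        has_error: false,
        error_msg: String::new(),
    };

    // Error path: sentinel token id, error flag raised, message forwarded
    // (matches `{u32::MAX, 0.0, true, true, what}` above).
    let err_step = GenerationStep {
        token_id: u32::MAX,
        log_prob: 0.0,
        is_final: true,
        has_error: true,
        error_msg: String::from("Executor reported a decoding failure"),
    };

    println!("{:?}\n{:?}", ok_step, err_step);
}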


@@ -8,11 +8,12 @@ mod ffi {
     /// Struct used as shared type between rust and C++ to represent the result
     /// of a single decoding iteration
-    #[derive(Copy, Clone)]
     pub struct GenerationStep {
         token_id: u32,
         log_prob: f32,
         is_final: bool,
+        has_error: bool,
+        error_msg: String,
     }

     extern "Rust" {