expose information about potential errors happening while decoding

Morgan Funtowicz 2024-07-18 22:07:59 +00:00
parent a19d318947
commit e82dc30e8a
3 changed files with 67 additions and 53 deletions


@@ -197,77 +197,85 @@ impl TensorRtLlmBackend {
                     .await;
 
             while let Some(_) = generation.next().await {
-                span!(Level::DEBUG, "decode", request_id = request_id)
-                    .in_scope(|| async {
-                        let mut executor_w = executor.write().await;
+                let mut executor_w = executor.write().await;
+                let executor = executor_w.pin_mut();
+                debug!("Acquired write lock stream");
+
+                span!(Level::DEBUG, "decode")
+                    .in_scope(|| async {
                         unsafe {
-                            debug!("Acquired write lock stream");
-                            executor_w.pin_mut().stream_tokens(
+                            executor.stream_tokens(
                                 request_id,
                                 ctx_,
                                 |ctx: *mut GenerationContext, step: GenerationStep| {
                                     let inner_ctx = &mut *ctx;
-                                    // Insert the latest generated token to the tracker
-                                    inner_ctx.tokens.push(step.token_id);
 
                                     // Update the timestamp at which the request started effectively
                                     // Can be a bit off, would need to be before the callback, let's see
                                     inner_ctx.start.get_or_insert(Instant::now());
                                     inner_ctx.done.store(step.is_final, Ordering::Relaxed);
 
-                                    // Decode the token
-                                    let text = inner_ctx
-                                        .tokenizer
-                                        .decode(&[step.token_id], true)
-                                        .expect("Failed to decode token");
-
-                                    let special = inner_ctx
-                                        .tokenizer
-                                        .get_added_vocabulary()
-                                        .is_special_token(&text);
-
-                                    // Create the structure holding the token
-                                    let token = Token {
-                                        id: step.token_id,
-                                        text,
-                                        logprob: step.log_prob,
-                                        special,
-                                    };
-
-                                    let out = if step.is_final {
-                                        inner_ctx.done.store(true, Ordering::Relaxed);
-                                        let generated_text = inner_ctx
-                                            .tokenizer
-                                            .decode(&inner_ctx.tokens, true)
-                                            .expect("Failed to decode generated_tokens");
-
-                                        InferStreamResponse::End {
-                                            token,
-                                            top_tokens: vec![],
-                                            generated_text: GeneratedText {
-                                                text: generated_text,
-                                                generated_tokens: inner_ctx.tokens.len() as u32,
-                                                finish_reason: FinishReason::EndOfSequenceToken,
-                                                seed: None,
-                                            },
-                                            start: inner_ctx.start.unwrap_or(Instant::now()),
-                                            queued: inner_ctx.queued,
-                                        }
-                                    } else {
-                                        InferStreamResponse::Intermediate {
-                                            token,
-                                            top_tokens: vec![],
-                                        }
+                                    // Ensure we are not running into errors
+                                    let parcel = if !step.has_error {
+                                        // Insert the latest generated token to the tracker
+                                        inner_ctx.tokens.push(step.token_id);
+
+                                        // Decode the token
+                                        let text = inner_ctx
+                                            .tokenizer
+                                            .decode(&[step.token_id], true)
+                                            .expect("Failed to decode token");
+
+                                        let special = inner_ctx
+                                            .tokenizer
+                                            .get_added_vocabulary()
+                                            .is_special_token(&text);
+
+                                        // Create the structure holding the token
+                                        let token = Token {
+                                            id: step.token_id,
+                                            text,
+                                            logprob: step.log_prob,
+                                            special,
+                                        };
+
+                                        if step.is_final {
+                                            let generated_text = inner_ctx
+                                                .tokenizer
+                                                .decode(&inner_ctx.tokens, true)
+                                                .expect("Failed to decode generated_tokens");
+
+                                            Ok(InferStreamResponse::End {
+                                                token,
+                                                top_tokens: vec![],
+                                                generated_text: GeneratedText {
+                                                    text: generated_text,
+                                                    generated_tokens: inner_ctx.tokens.len() as u32,
+                                                    finish_reason: FinishReason::EndOfSequenceToken,
+                                                    seed: None,
+                                                },
+                                                start: inner_ctx.start.unwrap_or(Instant::now()),
+                                                queued: inner_ctx.queued,
+                                            })
+                                        } else {
+                                            Ok(InferStreamResponse::Intermediate {
+                                                token,
+                                                top_tokens: vec![],
+                                            })
+                                        }
+                                    } else {
+                                        Err(InferError::GenerationError(step.error_msg))
                                     };
 
                                     // Send the parcel to the client
                                     inner_ctx
                                         .sender
-                                        .send(Ok(out))
-                                        .expect("Failed to send back generated token");
+                                        .send(parcel)
+                                        .expect("Failed to sent msg through the channel");
                                 },
                             );
-                            debug!("Releasing write lock stream")
                         }
+                        debug!("Releasing write lock stream");
                     })
                     .await;
             }
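Taken together, the new callback body first checks `step.has_error` before doing any tokenizer work, and the channel now carries a `Result` (the `parcel`) instead of an always-`Ok` value, so a failed decoding step reaches the client as `InferError::GenerationError` with the executor's message. Below is a minimal, self-contained sketch of that dispatch; `Step` and `Token` are simplified stand-ins, not the actual TGI `GenerationStep`, `Token`, or `InferStreamResponse` types.

    // Simplified stand-in for the FFI GenerationStep shared struct.
    #[derive(Debug)]
    struct Step {
        token_id: u32,
        log_prob: f32,
        is_final: bool,
        has_error: bool,
        error_msg: String,
    }

    // Simplified stand-in for the token structure sent back to the client.
    #[derive(Debug)]
    struct Token {
        id: u32,
        logprob: f32,
    }

    // Mirrors the `let parcel = if !step.has_error { Ok(..) } else { Err(..) }` dispatch:
    // successful steps become tokens, failed steps surface the executor's error message.
    fn step_to_parcel(step: Step) -> Result<Token, String> {
        if !step.has_error {
            Ok(Token {
                id: step.token_id,
                logprob: step.log_prob,
            })
        } else {
            Err(step.error_msg)
        }
    }

    fn main() {
        let ok_step = Step {
            token_id: 42,
            log_prob: -0.12,
            is_final: false,
            has_error: false,
            error_msg: String::new(),
        };
        let failed_step = Step {
            token_id: u32::MAX,
            log_prob: 0.0,
            is_final: true,
            has_error: true,
            error_msg: "executor reported a decoding error".to_string(),
        };

        // The sender forwards whatever the dispatch produced; the receiver can now tell
        // a real token apart from a decoding failure instead of only ever seeing Ok values.
        for step in [ok_step, failed_step] {
            match step_to_parcel(step) {
                Ok(token) => println!("token: {token:?}"),
                Err(msg) => println!("generation error: {msg}"),
            }
        }
    }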


@@ -54,12 +54,17 @@ size_t huggingface::tgi::backends::TensorRtLlmBackendImpl::StreamTokens(
             ++numTokens;
             SPDLOG_DEBUG(FMT_STRING("\tStreamTokens -> {:d} {:.2f} (final = {})"), token, logProb, isFinal);
-            step = huggingface::tgi::backends::GenerationStep{static_cast<uint32_t>(token), logProb, isFinal};
+            step = huggingface::tgi::backends::GenerationStep{
+                    static_cast<uint32_t>(token), logProb, isFinal, false, std::move(std::string())
+            };
             SPDLOG_DEBUG("\tStreamTokens -> Post callback");
         } else {
             // TODO : Return rest::Result with error
-            SPDLOG_WARN("\tStreamTokens -> Got error while decoding: {}", item.getErrorMsg());
-            step = huggingface::tgi::backends::GenerationStep{std::numeric_limits<uint32_t>::max(), 0.0, true};
+            const auto what = item.getErrorMsg();
+            SPDLOG_WARN("\tStreamTokens -> Got error while decoding: {}", what);
+            step = huggingface::tgi::backends::GenerationStep{
+                    std::numeric_limits<uint32_t>::max(), 0.0, true, true, std::move(what)
+            };
         }
 
         callback(std::move(ctx), std::move(step));
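With the two new fields, the C++ side now builds the step in exactly two shapes: a successful step carries `has_error = false` and an empty message, while a failure uses the sentinel token id `std::numeric_limits<uint32_t>::max()`, is marked final, sets `has_error = true`, and forwards the executor's error message. A small illustrative sketch of the same convention in Rust; the two helper functions are hypothetical, only the field layout follows the shared FFI struct shown in the next hunk.

    // Same field layout as the shared FFI struct; the helpers below are illustrative only.
    struct GenerationStep {
        token_id: u32,
        log_prob: f32,
        is_final: bool,
        has_error: bool,
        error_msg: String,
    }

    // Successful step: real token id, no error, empty message.
    fn step_ok(token_id: u32, log_prob: f32, is_final: bool) -> GenerationStep {
        GenerationStep {
            token_id,
            log_prob,
            is_final,
            has_error: false,
            error_msg: String::new(),
        }
    }

    // Failed step: sentinel token id, forced final, error flag set, message carried along.
    fn step_err(error_msg: String) -> GenerationStep {
        GenerationStep {
            token_id: u32::MAX,
            log_prob: 0.0,
            is_final: true,
            has_error: true,
            error_msg,
        }
    }

    fn main() {
        let ok = step_ok(1337, -0.42, false);
        let err = step_err("executor reported a decoding error".to_string());
        println!(
            "ok:  has_error={} token_id={} log_prob={} final={}",
            ok.has_error, ok.token_id, ok.log_prob, ok.is_final
        );
        println!("err: has_error={} msg={}", err.has_error, err.error_msg);
    }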


@@ -8,11 +8,12 @@ mod ffi {
     /// Struct used as shared type between rust and C++ to represent the result
     /// of a single decoding iteration
-    #[derive(Copy, Clone)]
     pub struct GenerationStep {
         token_id: u32,
         log_prob: f32,
         is_final: bool,
+        has_error: bool,
+        error_msg: String,
     }
 
     extern "Rust" {
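One side effect of adding `error_msg: String` is that the shared struct can no longer be `Copy`, which is why the `#[derive(Copy, Clone)]` attribute is dropped in the hunk above: `String` owns heap memory and only implements `Clone`. A quick stand-alone illustration; the type and field names here are placeholders, not the FFI definitions.

    // A struct embedding a `String` cannot derive `Copy`; deriving `Clone` alone still works,
    // so callers have to clone explicitly instead of relying on implicit copies.
    #[derive(Clone, Debug)]
    struct StepWithMessage {
        token_id: u32,
        error_msg: String, // owning, heap-allocated field rules out `Copy`
    }

    fn main() {
        let step = StepWithMessage {
            token_id: 7,
            error_msg: "decoding failed".to_string(),
        };

        // Explicit clone required; a plain assignment would move `step` out.
        let duplicate = step.clone();
        println!("{:?} / {:?}", step, duplicate);
    }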