Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-06-19 15:52:08 +00:00)
expose information about potential error happening while decoding
parent a19d318947
commit e82dc30e8a
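At a glance (a summary of the hunks below, not part of the original commit message): the GenerationStep struct shared between the Rust backend and the C++ TensorRT-LLM executor gains has_error and error_msg fields, the C++ StreamTokens path fills them in whenever the executor response carries an error, and the Rust decoding loop turns such steps into Err(InferError::GenerationError(step.error_msg)) parcels on the per-request channel instead of reporting a placeholder final token (token id u32::MAX). A sketch of how the receiving side of that channel can consume the new parcels follows the first hunk.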
@@ -197,77 +197,85 @@ impl TensorRtLlmBackend {
             .await;
 
         while let Some(_) = generation.next().await {
-            let mut executor_w = executor.write().await;
-            let executor = executor_w.pin_mut();
-
-            span!(Level::DEBUG, "decode")
+            span!(Level::DEBUG, "decode", request_id = request_id)
                 .in_scope(|| async {
+                    let mut executor_w = executor.write().await;
+
+                    debug!("Acquired write lock stream");
                     unsafe {
-                        debug!("Acquired write lock stream");
-                        executor.stream_tokens(
+                        executor_w.pin_mut().stream_tokens(
                             request_id,
                             ctx_,
                             |ctx: *mut GenerationContext, step: GenerationStep| {
                                 let inner_ctx = &mut *ctx;
 
-                                // Insert the latest generated token to the tracker
-                                inner_ctx.tokens.push(step.token_id);
+                                // Update the timestamp at which the request started effectively
+                                // Can be a bit off, would need to be before the callback, let's see
+                                inner_ctx.start.get_or_insert(Instant::now());
+                                inner_ctx.done.store(step.is_final, Ordering::Relaxed);
 
-                                // Decode the token
-                                let text = inner_ctx
-                                    .tokenizer
-                                    .decode(&[step.token_id], true)
-                                    .expect("Failed to decode token");
+                                // Ensure we are not running into errors
+                                let parcel = if !step.has_error {
+                                    // Insert the latest generated token to the tracker
+                                    inner_ctx.tokens.push(step.token_id);
 
-                                let special = inner_ctx
-                                    .tokenizer
-                                    .get_added_vocabulary()
-                                    .is_special_token(&text);
+                                    // Decode the token
+                                    let text = inner_ctx
+                                        .tokenizer
+                                        .decode(&[step.token_id], true)
+                                        .expect("Failed to decode token");
 
-                                // Create the structure holding the token
-                                let token = Token {
-                                    id: step.token_id,
-                                    text,
-                                    logprob: step.log_prob,
-                                    special,
-                                };
+                                    let special = inner_ctx
+                                        .tokenizer
+                                        .get_added_vocabulary()
+                                        .is_special_token(&text);
 
-                                let out = if step.is_final {
-                                    inner_ctx.done.store(true, Ordering::Relaxed);
-                                    let generated_text = inner_ctx
-                                        .tokenizer
-                                        .decode(&inner_ctx.tokens, true)
-                                        .expect("Failed to decode generated_tokens");
+                                    // Create the structure holding the token
+                                    let token = Token {
+                                        id: step.token_id,
+                                        text,
+                                        logprob: step.log_prob,
+                                        special,
+                                    };
 
-                                    InferStreamResponse::End {
-                                        token,
-                                        top_tokens: vec![],
-                                        generated_text: GeneratedText {
-                                            text: generated_text,
-                                            generated_tokens: inner_ctx.tokens.len() as u32,
-                                            finish_reason: FinishReason::EndOfSequenceToken,
-                                            seed: None,
-                                        },
-                                        start: inner_ctx.start.unwrap_or(Instant::now()),
-                                        queued: inner_ctx.queued,
-                                    }
-                                } else {
-                                    InferStreamResponse::Intermediate {
-                                        token,
-                                        top_tokens: vec![],
+                                    if step.is_final {
+                                        let generated_text = inner_ctx
+                                            .tokenizer
+                                            .decode(&inner_ctx.tokens, true)
+                                            .expect("Failed to decode generated_tokens");
+
+                                        Ok(InferStreamResponse::End {
+                                            token,
+                                            top_tokens: vec![],
+                                            generated_text: GeneratedText {
+                                                text: generated_text,
+                                                generated_tokens: inner_ctx.tokens.len() as u32,
+                                                finish_reason: FinishReason::EndOfSequenceToken,
+                                                seed: None,
+                                            },
+                                            start: inner_ctx.start.unwrap_or(Instant::now()),
+                                            queued: inner_ctx.queued,
+                                        })
+                                    } else {
+                                        Ok(InferStreamResponse::Intermediate {
+                                            token,
+                                            top_tokens: vec![],
+                                        })
                                     }
+                                } else {
+                                    Err(InferError::GenerationError(step.error_msg))
                                 };
 
                                 // Send the parcel to the client
                                 inner_ctx
                                     .sender
-                                    .send(Ok(out))
-                                    .expect("Failed to send back generated token");
+                                    .send(parcel)
+                                    .expect("Failed to sent msg through the channel");
                             },
                         );
-                        debug!("Releasing write lock stream")
                     }
+                    debug!("Releasing write lock stream");
                 })
                 .await;
         }
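The sender above now carries Result<InferStreamResponse, InferError> parcels, so an executor-side decoding failure reaches whoever holds the receiving half of the channel instead of being masked by a placeholder token. Below is a minimal sketch of that consuming side; the import paths, the unbounded-channel receiver type and the drain_stream function are assumptions made for illustration, not code from this commit.

    use text_generation_router::infer::{InferError, InferStreamResponse};
    use tokio::sync::mpsc::UnboundedReceiver;

    // Sketch only: drain one request's stream and react to the new error parcels.
    async fn drain_stream(mut rx: UnboundedReceiver<Result<InferStreamResponse, InferError>>) {
        while let Some(parcel) = rx.recv().await {
            match parcel {
                // Intermediate / End responses keep flowing exactly as before.
                Ok(_response) => { /* forward to the HTTP streaming layer */ }
                // New with this commit: the executor's error message surfaces as a GenerationError.
                Err(InferError::GenerationError(msg)) => {
                    tracing::error!("generation failed: {msg}");
                    break;
                }
                Err(other) => {
                    tracing::error!("stream error: {other}");
                    break;
                }
            }
        }
    }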
@@ -54,12 +54,17 @@ size_t huggingface::tgi::backends::TensorRtLlmBackendImpl::StreamTokens(
             ++numTokens;
 
             SPDLOG_DEBUG(FMT_STRING("\tStreamTokens -> {:d} {:.2f} (final = {})"), token, logProb, isFinal);
-            step = huggingface::tgi::backends::GenerationStep{static_cast<uint32_t>(token), logProb, isFinal};
+            step = huggingface::tgi::backends::GenerationStep{
+                    static_cast<uint32_t>(token), logProb, isFinal, false, std::move(std::string())
+            };
             SPDLOG_DEBUG("\tStreamTokens -> Post callback");
         } else {
             // TODO : Return rest::Result with error
-            SPDLOG_WARN("\tStreamTokens -> Got error while decoding: {}", item.getErrorMsg());
-            step = huggingface::tgi::backends::GenerationStep{std::numeric_limits<uint32_t>::max(), 0.0, true};
+            const auto what = item.getErrorMsg();
+            SPDLOG_WARN("\tStreamTokens -> Got error while decoding: {}", what);
+            step = huggingface::tgi::backends::GenerationStep{
+                    std::numeric_limits<uint32_t>::max(), 0.0, true, true, std::move(what)
+            };
         }
 
         callback(std::move(ctx), std::move(step));
@@ -8,11 +8,12 @@ mod ffi {
 
     /// Struct used as shared type between rust and C++ to represent the result
     /// of a single decoding iteration
-    #[derive(Copy, Clone)]
     pub struct GenerationStep {
         token_id: u32,
         log_prob: f32,
         is_final: bool,
+        has_error: bool,
+        error_msg: String,
     }
 
     extern "Rust" {
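Since error_msg is an owned String, the shared struct can no longer derive Copy (hence the derive dropped above), and a step is now moved rather than copied across the callback boundary. A small illustrative helper is sketched below; it is not part of the commit and assumes the generated struct fields are visible from where the helper lives.

    // Sketch only: turn a GenerationStep into a Result, moving the owned
    // error message out instead of copying the whole step.
    fn step_into_result(step: ffi::GenerationStep) -> Result<(u32, f32, bool), String> {
        if step.has_error {
            // Propagate the message produced by the executor.
            Err(step.error_msg)
        } else {
            Ok((step.token_id, step.log_prob, step.is_final))
        }
    }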