diff --git a/backends/trtllm/src/backend.rs b/backends/trtllm/src/backend.rs
index 1c4878e1..ac12e42e 100644
--- a/backends/trtllm/src/backend.rs
+++ b/backends/trtllm/src/backend.rs
@@ -197,77 +197,85 @@ impl TensorRtLlmBackend {
                 .await;
             while let Some(_) = generation.next().await {
-                span!(Level::DEBUG, "decode", request_id = request_id)
-                    .in_scope(|| async {
-                        let mut executor_w = executor.write().await;
+                let mut executor_w = executor.write().await;
+                let executor = executor_w.pin_mut();
+                debug!("Acquired write lock stream");
+                span!(Level::DEBUG, "decode")
+                    .in_scope(|| async {
                         unsafe {
-                            debug!("Acquired write lock stream");
-                            executor_w.pin_mut().stream_tokens(
+                            executor.stream_tokens(
                                 request_id,
                                 ctx_,
                                 |ctx: *mut GenerationContext, step: GenerationStep| {
                                     let inner_ctx = &mut *ctx;
-                                    // Insert the latest generated token to the tracker
-                                    inner_ctx.tokens.push(step.token_id);
-
                                     // Update the timestamp at which the request started effectively
                                     // Can be a bit off, would need to be before the callback, let's see
                                     inner_ctx.start.get_or_insert(Instant::now());
+                                    inner_ctx.done.store(step.is_final, Ordering::Relaxed);
-                                    // Decode the token
-                                    let text = inner_ctx
-                                        .tokenizer
-                                        .decode(&[step.token_id], true)
-                                        .expect("Failed to decode token");
+                                    // Ensure we are not running into errors
+                                    let parcel = if !step.has_error {
+                                        // Insert the latest generated token to the tracker
+                                        inner_ctx.tokens.push(step.token_id);
-                                    let special = inner_ctx
-                                        .tokenizer
-                                        .get_added_vocabulary()
-                                        .is_special_token(&text);
-
-                                    // Create the structure holding the token
-                                    let token = Token {
-                                        id: step.token_id,
-                                        text,
-                                        logprob: step.log_prob,
-                                        special,
-                                    };
-
-                                    let out = if step.is_final {
-                                        inner_ctx.done.store(true, Ordering::Relaxed);
-                                        let generated_text = inner_ctx
+                                        // Decode the token
+                                        let text = inner_ctx
                                             .tokenizer
-                                            .decode(&inner_ctx.tokens, true)
-                                            .expect("Failed to decode generated_tokens");
+                                            .decode(&[step.token_id], true)
+                                            .expect("Failed to decode token");
-                                        InferStreamResponse::End {
-                                            token,
-                                            top_tokens: vec![],
-                                            generated_text: GeneratedText {
-                                                text: generated_text,
-                                                generated_tokens: inner_ctx.tokens.len() as u32,
-                                                finish_reason: FinishReason::EndOfSequenceToken,
-                                                seed: None,
-                                            },
-                                            start: inner_ctx.start.unwrap_or(Instant::now()),
-                                            queued: inner_ctx.queued,
+                                        let special = inner_ctx
+                                            .tokenizer
+                                            .get_added_vocabulary()
+                                            .is_special_token(&text);
+
+                                        // Create the structure holding the token
+                                        let token = Token {
+                                            id: step.token_id,
+                                            text,
+                                            logprob: step.log_prob,
+                                            special,
+                                        };
+
+                                        if step.is_final {
+                                            let generated_text = inner_ctx
+                                                .tokenizer
+                                                .decode(&inner_ctx.tokens, true)
+                                                .expect("Failed to decode generated_tokens");
+
+                                            Ok(InferStreamResponse::End {
+                                                token,
+                                                top_tokens: vec![],
+                                                generated_text: GeneratedText {
+                                                    text: generated_text,
+                                                    generated_tokens: inner_ctx.tokens.len() as u32,
+                                                    finish_reason: FinishReason::EndOfSequenceToken,
+                                                    seed: None,
+                                                },
+                                                start: inner_ctx.start.unwrap_or(Instant::now()),
+                                                queued: inner_ctx.queued,
+                                            })
+                                        } else {
+                                            Ok(InferStreamResponse::Intermediate {
+                                                token,
+                                                top_tokens: vec![],
+                                            })
                                         }
                                     } else {
-                                        InferStreamResponse::Intermediate {
-                                            token,
-                                            top_tokens: vec![],
-                                        }
+                                        Err(InferError::GenerationError(step.error_msg))
                                     };
+
+                                    // Send the parcel to the client
                                     inner_ctx
                                         .sender
-                                        .send(Ok(out))
-                                        .expect("Failed to send back generated token");
+                                        .send(parcel)
+                                        .expect("Failed to send msg through the channel");
                                 },
                             );
-                            debug!("Releasing write lock stream")
                         }
+                        debug!("Releasing write lock stream");
                     })
                     .await;
             }
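The core change in `backend.rs` is that the streaming callback no longer panics when the executor reports a problem: it first checks `step.has_error`, builds a `parcel` that is either `Ok(InferStreamResponse::Intermediate/End { .. })` or `Err(InferError::GenerationError(..))`, and only then pushes that `Result` through the per-request channel. Below is a minimal, self-contained sketch of that branching, using local stand-in types (`GenerationStep`, the response/error enums and the `mpsc` channel here are placeholders, not the actual crate types).

```rust
use std::sync::mpsc;

// Local stand-in for the cxx-shared step (the real one also carries log_prob).
struct GenerationStep {
    token_id: u32,
    is_final: bool,
    has_error: bool,
    error_msg: String,
}

// Simplified stand-ins for InferStreamResponse / InferError.
#[derive(Debug)]
enum StreamResponse {
    Intermediate { token_id: u32 },
    End { token_id: u32 },
}

#[derive(Debug)]
enum StreamError {
    Generation(String),
}

fn handle_step(step: GenerationStep, sender: &mpsc::Sender<Result<StreamResponse, StreamError>>) {
    // Build the parcel first: a decoded token on success, the executor's
    // error message on failure. The callback itself never panics on errors.
    let parcel = if !step.has_error {
        if step.is_final {
            Ok(StreamResponse::End { token_id: step.token_id })
        } else {
            Ok(StreamResponse::Intermediate { token_id: step.token_id })
        }
    } else {
        Err(StreamError::Generation(step.error_msg))
    };

    // Forward the parcel to whoever is streaming the response to the client.
    sender
        .send(parcel)
        .expect("Failed to send msg through the channel");
}

fn main() {
    let (tx, rx) = mpsc::channel();
    handle_step(
        GenerationStep { token_id: 42, is_final: true, has_error: false, error_msg: String::new() },
        &tx,
    );
    println!("{:?}", rx.recv().unwrap()); // End { token_id: 42 }
}
```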
diff --git a/backends/trtllm/src/ffi.cpp b/backends/trtllm/src/ffi.cpp
index 017d0121..d6317a68 100644
--- a/backends/trtllm/src/ffi.cpp
+++ b/backends/trtllm/src/ffi.cpp
@@ -54,12 +54,17 @@ size_t huggingface::tgi::backends::TensorRtLlmBackendImpl::StreamTokens(
            ++numTokens;
 
            SPDLOG_DEBUG(FMT_STRING("\tStreamTokens -> {:d} {:.2f} (final = {})"), token, logProb, isFinal);
-            step = huggingface::tgi::backends::GenerationStep{static_cast<uint32_t>(token), logProb, isFinal};
+            step = huggingface::tgi::backends::GenerationStep{
+                    static_cast<uint32_t>(token), logProb, isFinal, false, std::move(std::string())
+            };
            SPDLOG_DEBUG("\tStreamTokens -> Post callback");
        } else {
            // TODO : Return rest::Result with error
-            SPDLOG_WARN("\tStreamTokens -> Got error while decoding: {}", item.getErrorMsg());
-            step = huggingface::tgi::backends::GenerationStep{std::numeric_limits<uint32_t>::max(), 0.0, true};
+            const auto what = item.getErrorMsg();
+            SPDLOG_WARN("\tStreamTokens -> Got error while decoding: {}", what);
+            step = huggingface::tgi::backends::GenerationStep{
+                    std::numeric_limits<uint32_t>::max(), 0.0, true, true, std::move(what)
+            };
        }
 
        callback(std::move(ctx), std::move(step));
diff --git a/backends/trtllm/src/lib.rs b/backends/trtllm/src/lib.rs
index 4b1ff751..1a804f88 100644
--- a/backends/trtllm/src/lib.rs
+++ b/backends/trtllm/src/lib.rs
@@ -8,11 +8,12 @@ mod ffi {
 
     /// Struct used as shared type between rust and C++ to represent the result
     /// of a single decoding iteration
-    #[derive(Copy, Clone)]
     pub struct GenerationStep {
         token_id: u32,
         log_prob: f32,
         is_final: bool,
+        has_error: bool,
+        error_msg: String,
     }
 
     extern "Rust" {
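The `lib.rs` hunk drops `#[derive(Copy, Clone)]` from the shared struct as a direct consequence of carrying the message across the FFI boundary: once `GenerationStep` owns an `error_msg: String` it can no longer be `Copy`, so each step is moved into the Rust callback instead of being copied. A small standalone illustration of that ownership change, using a local stand-in struct rather than the actual cxx bridge definition:

```rust
// Stand-in mirroring the shared GenerationStep; the owning String field is
// what rules out `#[derive(Copy)]`, so values of this type are moved, not copied.
struct GenerationStep {
    token_id: u32,
    log_prob: f32,
    is_final: bool,
    has_error: bool,
    error_msg: String,
}

fn consume(step: GenerationStep) {
    if step.has_error {
        eprintln!("decoding failed: {}", step.error_msg);
    } else {
        println!(
            "token {} (logprob {:.2}, final = {})",
            step.token_id, step.log_prob, step.is_final
        );
    }
}

fn main() {
    let step = GenerationStep {
        token_id: u32::MAX,
        log_prob: 0.0,
        is_final: true,
        has_error: true,
        error_msg: "executor reported an error".to_string(),
    };
    consume(step); // `step` is moved here; a second use would not compile
}
```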