define a shared struct to hold the result of a decoding step

Morgan Funtowicz 2024-07-18 21:33:04 +00:00
parent a036574a86
commit a19d318947
4 changed files with 30 additions and 21 deletions
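For reference, a shared struct declared in the #[cxx::bridge] module (see the lib.rs hunk at the end of this diff) is also emitted as a plain C++ definition by the cxx code generator, which is what lets StreamTokens hand the step to the callback by value. A rough, illustrative sketch of what the C++ side sees, assuming the generated field names and types mirror the Rust declaration (u32 -> uint32_t, f32 -> float, bool -> bool); the actual generated header comes from cxx and may differ in detail:

// Illustrative mirror of the GenerationStep shared struct, not the actual generated code.
#include <cstdint>

namespace huggingface::tgi::backends {
    struct GenerationStep {
        std::uint32_t token_id; // id of the token produced at this step
        float log_prob;         // log-probability reported for that token
        bool is_final;          // true when this step completes the request
    };
}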

View File

@@ -60,7 +60,8 @@ namespace huggingface::tgi::backends {
size_t StreamTokens(
const RequestId requestId,
huggingface::tgi::backends::GenerationContext *ctx,
rust::Fn<void(huggingface::tgi::backends::GenerationContext *, uint32_t, float_t, bool)> callback);
rust::Fn<void(huggingface::tgi::backends::GenerationContext *,
huggingface::tgi::backends::GenerationStep)> callback);
};
/***

View File

@@ -20,13 +20,11 @@ use tracing::{instrument, Level, span};
use text_generation_router::{FinishReason, Token};
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
use text_generation_router::validation::{
Chunk, ValidationError, ValidGenerateRequest, ValidParameters,
};
use text_generation_router::validation::{Chunk, ValidationError, ValidGenerateRequest};
use text_generation_router::validation::ValidationError::UnsupportedModality;
use crate::errors::TensorRtLlmBackendError;
use crate::ffi::{create_tensorrt_llm_backend, TensorRtLlmBackendImpl};
use crate::ffi::{create_tensorrt_llm_backend, GenerationStep, TensorRtLlmBackendImpl};
// Value used to poll the state of the generation stream
static POLLING_INTERVAL_US: OnceLock<u64> = OnceLock::new();
@@ -208,14 +206,11 @@ impl TensorRtLlmBackend {
executor_w.pin_mut().stream_tokens(
request_id,
ctx_,
|ctx: *mut GenerationContext,
token_id: u32,
logprob: f32,
is_final: bool| {
|ctx: *mut GenerationContext, step: GenerationStep| {
let inner_ctx = &mut *ctx;
// Insert the latest generated token to the tracker
inner_ctx.tokens.push(token_id);
inner_ctx.tokens.push(step.token_id);
// Update the timestamp at which the request started effectively
// Can be a bit off, would need to be before the callback, let's see
@@ -224,7 +219,7 @@ impl TensorRtLlmBackend {
// Decode the token
let text = inner_ctx
.tokenizer
.decode(&[token_id], true)
.decode(&[step.token_id], true)
.expect("Failed to decode token");
let special = inner_ctx
@@ -234,13 +229,13 @@ impl TensorRtLlmBackend {
// Create the structure holding the token
let token = Token {
id: token_id,
id: step.token_id,
text,
logprob,
logprob: step.log_prob,
special,
};
let out = if is_final {
let out = if step.is_final {
inner_ctx.done.store(true, Ordering::Relaxed);
let generated_text = inner_ctx
.tokenizer

View File

@@ -6,6 +6,7 @@
#include <cmath>
#include <exception>
#include <filesystem>
#include <limits>
#include <iterator>
#include <vector>
@@ -36,10 +37,12 @@ uint64_t huggingface::tgi::backends::TensorRtLlmBackendImpl::Submit(
size_t huggingface::tgi::backends::TensorRtLlmBackendImpl::StreamTokens(
const uint64_t requestId,
huggingface::tgi::backends::GenerationContext *ctx,
rust::Fn<void(huggingface::tgi::backends::GenerationContext *, uint32_t, float_t, bool)> callback) {
rust::Fn<void(huggingface::tgi::backends::GenerationContext *,
huggingface::tgi::backends::GenerationStep)> callback) {
size_t numTokens = 0;
for (const auto &item: Poll(requestId)) {
GenerationStep step;
if (!item.hasError()) {
SPDLOG_DEBUG("\tStreamTokens -> Decoding token...");
const auto decoded = item.getResult();
@@ -51,13 +54,15 @@ size_t huggingface::tgi::backends::TensorRtLlmBackendImpl::StreamTokens(
++numTokens;
SPDLOG_DEBUG(FMT_STRING("\tStreamTokens -> {:d} {:.2f} (final = {})"), token, logProb, isFinal);
callback(std::move(ctx), token, logProb, isFinal);
step = huggingface::tgi::backends::GenerationStep{static_cast<uint32_t>(token), logProb, isFinal};
SPDLOG_DEBUG("\tStreamTokens -> Post callback");
} else {
// TODO : Return rest::Result with error
SPDLOG_WARN("\tStreamTokens -> Got error while decoding: {}", item.getErrorMsg());
callback(std::move(ctx), 0, 0.0, true);
step = huggingface::tgi::backends::GenerationStep{std::numeric_limits<uint32_t>::max(), 0.0, true};
}
callback(std::move(ctx), std::move(step));
}
return numTokens;
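Both GenerationStep constructions above use positional aggregate initialization, so the argument order has to stay in sync with the field order declared in the Rust bridge (token_id, log_prob, is_final). As a hedged alternative, not part of this commit, the error-branch value could be built with C++20 designated initializers to make that coupling explicit; the uint32_t max sentinel and the field layout are taken from the diff above, and the struct definition below only stands in for the cxx-generated one:

// Hypothetical sketch: designated initializers for the error-branch step.
#include <cstdint>
#include <limits>

namespace huggingface::tgi::backends {
    // Stand-in for the cxx-generated shared struct.
    struct GenerationStep {
        std::uint32_t token_id;
        float log_prob;
        bool is_final;
    };

    inline GenerationStep MakeErrorStep() {
        return GenerationStep{
            .token_id = std::numeric_limits<std::uint32_t>::max(), // sentinel: no valid token
            .log_prob = 0.0f,
            .is_final = true,                                      // an error ends the stream
        };
    }
}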

View File

@@ -1,12 +1,20 @@
pub use backend::TensorRtLlmBackend;
use crate::backend::GenerationContext;
pub use backend::{GenerationContext, TensorRtLlmBackend};
mod backend;
pub mod errors;
#[cxx::bridge(namespace = "huggingface::tgi::backends")]
mod ffi {
/// Struct used as shared type between rust and C++ to represent the result
/// of a single decoding iteration
#[derive(Copy, Clone)]
pub struct GenerationStep {
token_id: u32,
log_prob: f32,
is_final: bool,
}
extern "Rust" {
type GenerationContext;
}
@@ -60,7 +68,7 @@ mod ffi {
self: Pin<&mut TensorRtLlmBackendImpl>,
request_id: u64,
ctx: *mut GenerationContext,
cb: unsafe fn(*mut GenerationContext, u32, f32, bool),
cb: unsafe fn(*mut GenerationContext, GenerationStep),
) -> usize;
// #[rust_name = "shutdown"]