Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-06-19 15:52:08 +00:00)
define a shared struct to hold the result of a decoding step

commit a19d318947 · parent a036574a86
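The change in one sentence: the C++ → Rust streaming callback used to pass four loose scalars (context pointer, token id, logprob, finality flag); it now passes the context pointer plus a single GenerationStep struct declared once in the cxx bridge, so both languages share one definition of a decoding step.

A minimal sketch of the cxx shared-struct pattern this relies on, assuming the cxx crate is a dependency; the Point/magnitude names are illustrative, not from this commit:

    #[cxx::bridge]
    mod ffi {
        // A struct declared inside the bridge is a "shared struct": cxx
        // emits a matching C++ definition, so values cross the FFI
        // boundary by value with no conversion glue.
        #[derive(Copy, Clone)]
        pub struct Point {
            x: f32,
            y: f32,
        }

        extern "Rust" {
            // Callable from C++; receives the shared struct by value.
            fn magnitude(p: Point) -> f32;
        }
    }

    // The implementation lives in ordinary Rust, outside the bridge module.
    fn magnitude(p: ffi::Point) -> f32 {
        (p.x * p.x + p.y * p.y).sqrt()
    }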
C++ FFI header — the StreamTokens declaration now takes the shared GenerationStep:

@@ -60,7 +60,8 @@ namespace huggingface::tgi::backends {
         size_t StreamTokens(
                 const RequestId requestId,
                 huggingface::tgi::backends::GenerationContext *ctx,
-                rust::Fn<void(huggingface::tgi::backends::GenerationContext *, uint32_t, float_t, bool)> callback);
+                rust::Fn<void(huggingface::tgi::backends::GenerationContext *,
+                              huggingface::tgi::backends::GenerationStep)> callback);
     };

     /***
Rust backend — imports updated to pull GenerationStep from the generated ffi module:

@@ -20,13 +20,11 @@ use tracing::{instrument, Level, span};

 use text_generation_router::{FinishReason, Token};
 use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
-use text_generation_router::validation::{
-    Chunk, ValidationError, ValidGenerateRequest, ValidParameters,
-};
+use text_generation_router::validation::{Chunk, ValidationError, ValidGenerateRequest};
 use text_generation_router::validation::ValidationError::UnsupportedModality;

 use crate::errors::TensorRtLlmBackendError;
-use crate::ffi::{create_tensorrt_llm_backend, TensorRtLlmBackendImpl};
+use crate::ffi::{create_tensorrt_llm_backend, GenerationStep, TensorRtLlmBackendImpl};

 // Value used to poll the state of the generation stream
 static POLLING_INTERVAL_US: OnceLock<u64> = OnceLock::new();
@@ -208,14 +206,11 @@ impl TensorRtLlmBackend {
                 executor_w.pin_mut().stream_tokens(
                     request_id,
                     ctx_,
-                    |ctx: *mut GenerationContext,
-                     token_id: u32,
-                     logprob: f32,
-                     is_final: bool| {
+                    |ctx: *mut GenerationContext, step: GenerationStep| {
                         let inner_ctx = &mut *ctx;

                         // Insert the latest generated token to the tracker
-                        inner_ctx.tokens.push(token_id);
+                        inner_ctx.tokens.push(step.token_id);

                         // Update the timestamp at which the request started effectively
                         // Can be a bit off, would need to be before the callback, let's see
@@ -224,7 +219,7 @@ impl TensorRtLlmBackend {
                         // Decode the token
                         let text = inner_ctx
                             .tokenizer
-                            .decode(&[token_id], true)
+                            .decode(&[step.token_id], true)
                             .expect("Failed to decode token");

                         let special = inner_ctx
@@ -234,13 +229,13 @@ impl TensorRtLlmBackend {

                         // Create the structure holding the token
                         let token = Token {
-                            id: token_id,
+                            id: step.token_id,
                             text,
-                            logprob,
+                            logprob: step.log_prob,
                             special,
                         };

-                        let out = if is_final {
+                        let out = if step.is_final {
                             inner_ctx.done.store(true, Ordering::Relaxed);
                             let generated_text = inner_ctx
                                 .tokenizer
C++ FFI implementation — <limits> is included for the error sentinel used below:

@@ -6,6 +6,7 @@
 #include <cmath>
 #include <exception>
 #include <filesystem>
+#include <limits>
 #include <iterator>
 #include <vector>

@@ -36,10 +37,12 @@ uint64_t huggingface::tgi::backends::TensorRtLlmBackendImpl::Submit(
 size_t huggingface::tgi::backends::TensorRtLlmBackendImpl::StreamTokens(
         const uint64_t requestId,
         huggingface::tgi::backends::GenerationContext *ctx,
-        rust::Fn<void(huggingface::tgi::backends::GenerationContext *, uint32_t, float_t, bool)> callback) {
+        rust::Fn<void(huggingface::tgi::backends::GenerationContext *,
+                      huggingface::tgi::backends::GenerationStep)> callback) {

     size_t numTokens = 0;
     for (const auto &item: Poll(requestId)) {
+        GenerationStep step;
         if (!item.hasError()) {
             SPDLOG_DEBUG("\tStreamTokens -> Decoding token...");
             const auto decoded = item.getResult();
@@ -51,13 +54,15 @@ size_t huggingface::tgi::backends::TensorRtLlmBackendImpl::StreamTokens(
             ++numTokens;

             SPDLOG_DEBUG(FMT_STRING("\tStreamTokens -> {:d} {:.2f} (final = {})"), token, logProb, isFinal);
-            callback(std::move(ctx), token, logProb, isFinal);
+            step = huggingface::tgi::backends::GenerationStep{static_cast<uint32_t>(token), logProb, isFinal};
             SPDLOG_DEBUG("\tStreamTokens -> Post callback");
         } else {
             // TODO : Return rest::Result with error
             SPDLOG_WARN("\tStreamTokens -> Got error while decoding: {}", item.getErrorMsg());
-            callback(std::move(ctx), 0, 0.0, true);
+            step = huggingface::tgi::backends::GenerationStep{std::numeric_limits<uint32_t>::max(), 0.0, true};
         }
+
+        callback(std::move(ctx), std::move(step));
     }

     return numTokens;
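Worth noting in the hunk above: rather than invoking the callback separately in each branch, both branches now fill a local GenerationStep and a single callback call follows the if/else; the error path encodes failure as a sentinel token id of std::numeric_limits<uint32_t>::max() with is_final set to true. A hedged sketch of how a Rust consumer could branch on that sentinel — the struct is mirrored locally so the snippet stands alone, and the commit itself adds no such check:

    #[derive(Copy, Clone)]
    pub struct GenerationStep {
        pub token_id: u32,
        pub log_prob: f32,
        pub is_final: bool,
    }

    fn handle(step: GenerationStep) {
        // u32::MAX mirrors std::numeric_limits<uint32_t>::max() on the C++ side.
        if step.token_id == u32::MAX && step.is_final {
            eprintln!("backend reported a decoding error");
        } else {
            println!(
                "token {} (logprob {:.2}, final = {})",
                step.token_id, step.log_prob, step.is_final
            );
        }
    }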
Rust cxx bridge module — GenerationContext is now re-exported and the shared GenerationStep struct is declared in the bridge:

@@ -1,12 +1,20 @@
-pub use backend::TensorRtLlmBackend;
-
-use crate::backend::GenerationContext;
+pub use backend::{GenerationContext, TensorRtLlmBackend};

 mod backend;
 pub mod errors;

 #[cxx::bridge(namespace = "huggingface::tgi::backends")]
 mod ffi {
+
+    /// Struct used as shared type between rust and C++ to represent the result
+    /// of a single decoding iteration
+    #[derive(Copy, Clone)]
+    pub struct GenerationStep {
+        token_id: u32,
+        log_prob: f32,
+        is_final: bool,
+    }

     extern "Rust" {
         type GenerationContext;
     }
@@ -60,7 +68,7 @@ mod ffi {
             self: Pin<&mut TensorRtLlmBackendImpl>,
             request_id: u64,
             ctx: *mut GenerationContext,
-            cb: unsafe fn(*mut GenerationContext, u32, f32, bool),
+            cb: unsafe fn(*mut GenerationContext, GenerationStep),
         ) -> usize;

         // #[rust_name = "shutdown"]
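On the bridge boundary, the unsafe fn pointer declared here is what the C++ header above sees as rust::Fn<void(GenerationContext*, GenerationStep)>, and the closure handed to stream_tokens in the backend must stay non-capturing so it can coerce to a plain fn pointer. A self-contained sketch of that coercion, with the bridge types mirrored locally for illustration — this is not code from the commit:

    #[derive(Copy, Clone)]
    struct GenerationStep {
        token_id: u32,
        log_prob: f32,
        is_final: bool,
    }

    struct GenerationContext;

    // Stand-in for the bridged stream_tokens parameter: a plain unsafe fn
    // pointer rather than a boxed closure.
    fn take_callback(cb: unsafe fn(*mut GenerationContext, GenerationStep)) {
        let _ = cb;
    }

    fn main() {
        // A non-capturing closure coerces to the (unsafe) fn pointer type;
        // capturing any environment here would be a compile error.
        take_callback(|_ctx: *mut GenerationContext, step: GenerationStep| {
            let _ = (step.token_id, step.log_prob, step.is_final);
        });
    }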