mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-06-19 15:52:08 +00:00
define a shared struct to hold the result of a decoding step
This commit is contained in:
parent
a036574a86
commit
a19d318947
@ -60,7 +60,8 @@ namespace huggingface::tgi::backends {
|
|||||||
size_t StreamTokens(
|
size_t StreamTokens(
|
||||||
const RequestId requestId,
|
const RequestId requestId,
|
||||||
huggingface::tgi::backends::GenerationContext *ctx,
|
huggingface::tgi::backends::GenerationContext *ctx,
|
||||||
rust::Fn<void(huggingface::tgi::backends::GenerationContext *, uint32_t, float_t, bool)> callback);
|
rust::Fn<void(huggingface::tgi::backends::GenerationContext *,
|
||||||
|
huggingface::tgi::backends::GenerationStep)> callback);
|
||||||
};
|
};
|
||||||
|
|
||||||
/***
|
/***
|
||||||
|
@ -20,13 +20,11 @@ use tracing::{instrument, Level, span};
|
|||||||
|
|
||||||
use text_generation_router::{FinishReason, Token};
|
use text_generation_router::{FinishReason, Token};
|
||||||
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
|
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
|
||||||
use text_generation_router::validation::{
|
use text_generation_router::validation::{Chunk, ValidationError, ValidGenerateRequest};
|
||||||
Chunk, ValidationError, ValidGenerateRequest, ValidParameters,
|
|
||||||
};
|
|
||||||
use text_generation_router::validation::ValidationError::UnsupportedModality;
|
use text_generation_router::validation::ValidationError::UnsupportedModality;
|
||||||
|
|
||||||
use crate::errors::TensorRtLlmBackendError;
|
use crate::errors::TensorRtLlmBackendError;
|
||||||
use crate::ffi::{create_tensorrt_llm_backend, TensorRtLlmBackendImpl};
|
use crate::ffi::{create_tensorrt_llm_backend, GenerationStep, TensorRtLlmBackendImpl};
|
||||||
|
|
||||||
// Value used to poll the state of the generation stream
|
// Value used to poll the state of the generation stream
|
||||||
static POLLING_INTERVAL_US: OnceLock<u64> = OnceLock::new();
|
static POLLING_INTERVAL_US: OnceLock<u64> = OnceLock::new();
|
||||||
@ -208,14 +206,11 @@ impl TensorRtLlmBackend {
|
|||||||
executor_w.pin_mut().stream_tokens(
|
executor_w.pin_mut().stream_tokens(
|
||||||
request_id,
|
request_id,
|
||||||
ctx_,
|
ctx_,
|
||||||
|ctx: *mut GenerationContext,
|
|ctx: *mut GenerationContext, step: GenerationStep| {
|
||||||
token_id: u32,
|
|
||||||
logprob: f32,
|
|
||||||
is_final: bool| {
|
|
||||||
let inner_ctx = &mut *ctx;
|
let inner_ctx = &mut *ctx;
|
||||||
|
|
||||||
// Insert the latest generated token to the tracker
|
// Insert the latest generated token to the tracker
|
||||||
inner_ctx.tokens.push(token_id);
|
inner_ctx.tokens.push(step.token_id);
|
||||||
|
|
||||||
// Update the timestamp at which the request started effectively
|
// Update the timestamp at which the request started effectively
|
||||||
// Can be a bit off, would need to be before the callback, let's see
|
// Can be a bit off, would need to be before the callback, let's see
|
||||||
@ -224,7 +219,7 @@ impl TensorRtLlmBackend {
|
|||||||
// Decode the token
|
// Decode the token
|
||||||
let text = inner_ctx
|
let text = inner_ctx
|
||||||
.tokenizer
|
.tokenizer
|
||||||
.decode(&[token_id], true)
|
.decode(&[step.token_id], true)
|
||||||
.expect("Failed to decode token");
|
.expect("Failed to decode token");
|
||||||
|
|
||||||
let special = inner_ctx
|
let special = inner_ctx
|
||||||
@ -234,13 +229,13 @@ impl TensorRtLlmBackend {
|
|||||||
|
|
||||||
// Create the structure holding the token
|
// Create the structure holding the token
|
||||||
let token = Token {
|
let token = Token {
|
||||||
id: token_id,
|
id: step.token_id,
|
||||||
text,
|
text,
|
||||||
logprob,
|
logprob: step.log_prob,
|
||||||
special,
|
special,
|
||||||
};
|
};
|
||||||
|
|
||||||
let out = if is_final {
|
let out = if step.is_final {
|
||||||
inner_ctx.done.store(true, Ordering::Relaxed);
|
inner_ctx.done.store(true, Ordering::Relaxed);
|
||||||
let generated_text = inner_ctx
|
let generated_text = inner_ctx
|
||||||
.tokenizer
|
.tokenizer
|
||||||
|
@ -6,6 +6,7 @@
|
|||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include <exception>
|
#include <exception>
|
||||||
#include <filesystem>
|
#include <filesystem>
|
||||||
|
#include <limits>
|
||||||
#include <iterator>
|
#include <iterator>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
@ -36,10 +37,12 @@ uint64_t huggingface::tgi::backends::TensorRtLlmBackendImpl::Submit(
|
|||||||
size_t huggingface::tgi::backends::TensorRtLlmBackendImpl::StreamTokens(
|
size_t huggingface::tgi::backends::TensorRtLlmBackendImpl::StreamTokens(
|
||||||
const uint64_t requestId,
|
const uint64_t requestId,
|
||||||
huggingface::tgi::backends::GenerationContext *ctx,
|
huggingface::tgi::backends::GenerationContext *ctx,
|
||||||
rust::Fn<void(huggingface::tgi::backends::GenerationContext *, uint32_t, float_t, bool)> callback) {
|
rust::Fn<void(huggingface::tgi::backends::GenerationContext *,
|
||||||
|
huggingface::tgi::backends::GenerationStep)> callback) {
|
||||||
|
|
||||||
size_t numTokens = 0;
|
size_t numTokens = 0;
|
||||||
for (const auto &item: Poll(requestId)) {
|
for (const auto &item: Poll(requestId)) {
|
||||||
|
GenerationStep step;
|
||||||
if (!item.hasError()) {
|
if (!item.hasError()) {
|
||||||
SPDLOG_DEBUG("\tStreamTokens -> Decoding token...");
|
SPDLOG_DEBUG("\tStreamTokens -> Decoding token...");
|
||||||
const auto decoded = item.getResult();
|
const auto decoded = item.getResult();
|
||||||
@ -51,13 +54,15 @@ size_t huggingface::tgi::backends::TensorRtLlmBackendImpl::StreamTokens(
|
|||||||
++numTokens;
|
++numTokens;
|
||||||
|
|
||||||
SPDLOG_DEBUG(FMT_STRING("\tStreamTokens -> {:d} {:.2f} (final = {})"), token, logProb, isFinal);
|
SPDLOG_DEBUG(FMT_STRING("\tStreamTokens -> {:d} {:.2f} (final = {})"), token, logProb, isFinal);
|
||||||
callback(std::move(ctx), token, logProb, isFinal);
|
step = huggingface::tgi::backends::GenerationStep{static_cast<uint32_t>(token), logProb, isFinal};
|
||||||
SPDLOG_DEBUG("\tStreamTokens -> Post callback");
|
SPDLOG_DEBUG("\tStreamTokens -> Post callback");
|
||||||
} else {
|
} else {
|
||||||
// TODO : Return rest::Result with error
|
// TODO : Return rest::Result with error
|
||||||
SPDLOG_WARN("\tStreamTokens -> Got error while decoding: {}", item.getErrorMsg());
|
SPDLOG_WARN("\tStreamTokens -> Got error while decoding: {}", item.getErrorMsg());
|
||||||
callback(std::move(ctx), 0, 0.0, true);
|
step = huggingface::tgi::backends::GenerationStep{std::numeric_limits<uint32_t>::max(), 0.0, true};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
callback(std::move(ctx), std::move(step));
|
||||||
}
|
}
|
||||||
|
|
||||||
return numTokens;
|
return numTokens;
|
||||||
|
@ -1,12 +1,20 @@
|
|||||||
pub use backend::TensorRtLlmBackend;
|
pub use backend::{GenerationContext, TensorRtLlmBackend};
|
||||||
|
|
||||||
use crate::backend::GenerationContext;
|
|
||||||
|
|
||||||
mod backend;
|
mod backend;
|
||||||
pub mod errors;
|
pub mod errors;
|
||||||
|
|
||||||
#[cxx::bridge(namespace = "huggingface::tgi::backends")]
|
#[cxx::bridge(namespace = "huggingface::tgi::backends")]
|
||||||
mod ffi {
|
mod ffi {
|
||||||
|
|
||||||
|
/// Struct used as shared type between rust and C++ to represent the result
|
||||||
|
/// of a single decoding iteration
|
||||||
|
#[derive(Copy, Clone)]
|
||||||
|
pub struct GenerationStep {
|
||||||
|
token_id: u32,
|
||||||
|
log_prob: f32,
|
||||||
|
is_final: bool,
|
||||||
|
}
|
||||||
|
|
||||||
extern "Rust" {
|
extern "Rust" {
|
||||||
type GenerationContext;
|
type GenerationContext;
|
||||||
}
|
}
|
||||||
@ -60,7 +68,7 @@ mod ffi {
|
|||||||
self: Pin<&mut TensorRtLlmBackendImpl>,
|
self: Pin<&mut TensorRtLlmBackendImpl>,
|
||||||
request_id: u64,
|
request_id: u64,
|
||||||
ctx: *mut GenerationContext,
|
ctx: *mut GenerationContext,
|
||||||
cb: unsafe fn(*mut GenerationContext, u32, f32, bool),
|
cb: unsafe fn(*mut GenerationContext, GenerationStep),
|
||||||
) -> usize;
|
) -> usize;
|
||||||
|
|
||||||
// #[rust_name = "shutdown"]
|
// #[rust_name = "shutdown"]
|
||||||
|
Loading…
Reference in New Issue
Block a user