feat(backend): refactor the callback to handle intermediate and end inference messages

Morgan Funtowicz 2024-11-04 16:17:43 +01:00
parent 11c593dc69
commit 5b7a951389
5 changed files with 142 additions and 141 deletions
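In short, the per-request callback becomes the single path for both intermediate tokens and the final message: the C++ worker now exposes a stream() method that invokes the callback for every decoded token with an is_final flag, and the Rust side builds either an InferStreamResponse::Intermediate or an InferStreamResponse::End and pushes it onto the request's channel. The sketch below condenses that dispatch for illustration only; it reuses the types introduced in this diff (InferContext, GenerationContext, Token, GeneratedText, InferStreamResponse) but takes a safe &mut reference instead of the raw pointer used in the committed code, and is not the exact implementation.

    // Illustrative sketch of the new callback dispatch (condensed from backend.rs below).
    // Assumes the InferContext / GenerationContext types defined in this commit.
    fn on_token(ctx: &mut InferContext, id: u32, logprob: f32, is_final: bool, n_generated: usize) {
        // Keep track of every decoded token on the Rust side
        ctx.generation.generated_tokens.push(id);

        let token = Token { id, text: String::new(), logprob, special: false };

        let response = if !is_final {
            // Intermediate tokens are streamed back as soon as they are decoded
            InferStreamResponse::Intermediate { token, top_tokens: vec![] }
        } else {
            // The last token carries the end-of-stream payload (finish reason, seed, timings)
            InferStreamResponse::End {
                token,
                top_tokens: vec![],
                generated_text: GeneratedText {
                    text: String::new(), // full text decoding is still a TODO in this commit
                    generated_tokens: n_generated as u32,
                    finish_reason: FinishReason::Length,
                    seed: Some(ctx.generation.sampling_params.seed),
                },
                start: ctx.start,
                queued: ctx.start,
            }
        };

        // Forward to the per-request unbounded channel owned by the router
        let _ = ctx.stream.send(Ok(response));
    }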

View File

@@ -114,9 +114,6 @@ namespace huggingface::tgi::backends::llamacpp {
         auto is_eog = llama_token_is_eog(mModel_.get(), new_token_id);
         auto new_token_logits = 0.0f; // TODO: return logit
 
-        // Push the token to the generated vector on Rust side
-        generation_context.generated_tokens[n_decoded_tokens] = new_token_id;
-
         // Handle termination cases
         const auto has_reach_max_tokens = n_decoded_tokens >= max_new_tokens - 1;
         const auto has_reach_eog = !generation_context.generation_params.ignore_eos_token & is_eog;
@@ -150,10 +147,15 @@
     ) {
         // TODO: Should we provide a way to change this value?
         auto generated = std::vector<llama_token>(2 << 8);
 
-        auto nTokensGenerated = generate(tokens, generated, generation_params, sampling_params, callback);
-        if (nTokensGenerated.has_value())
-            generated.resize(*nTokensGenerated);
+        auto inner_callback = [&](uint32_t new_token_id, float_t new_token_logit, bool is_eos,
+                                  size_t num_generated_tokens) {
+            generated.emplace_back(new_token_id);
+
+            if (callback.has_value())
+                (*callback)(new_token_id, new_token_logit, is_eos, num_generated_tokens);
+        };
+
+        auto nTokensGenerated = stream(tokens, generation_params, sampling_params, inner_callback);
         return generated;
     }
@@ -168,25 +170,24 @@
             llama_numa_init(ggml_numa_strategy::GGML_NUMA_STRATEGY_NUMACTL);
     }
 
-    std::expected<std::size_t, backend_error_t>
-    single_worker_backend_t::generate(
+    std::expected<size_t, backend_error_t>
+    single_worker_backend_t::stream(
             std::span<const llama_token> tokens,
-            std::span<llama_token> out,
             const generation_params_t &generation_params,
             const sampling_params_t &sampling_params,
-            const std::optional<llama_decode_callback> &callback
+            const llama_decode_callback &callback
     ) {
-        return mWorker_.generate(mContext_.get(), {generation_params, sampling_params, tokens, out}, callback);
+        return mWorker_.generate(mContext_.get(), {generation_params, sampling_params, tokens}, callback);
     }
 
     std::expected<size_t, backend_error_t>
-    multi_worker_backend_t::generate(
-            std::span<const llama_token>,
-            std::span<llama_token>,
+    multi_worker_backend_t::stream(
+            std::span<const llama_token> tokens,
             const generation_params_t &generation_params,
             const sampling_params_t &sampling_params,
-            const std::optional<llama_decode_callback> &callback) {
-        SPDLOG_ERROR("Not implemented yet");
-        return 0uz;
+            const llama_decode_callback &callback
+    ) {
+        SPDLOG_WARN("Not implemented for multi_worker_t");
+        return 0;
     }
 }

View File

@@ -69,7 +69,6 @@ namespace huggingface::tgi::backends::llamacpp {
         generation_params_t generation_params;
         sampling_params_t sampling_params;
         std::span<const llama_token> input_tokens;
-        std::span<llama_token> generated_tokens;
     };
 
     /**
@@ -125,25 +124,9 @@
         /**
          *
          * @param tokens
-         * @params out
-         * @param params
-         * @param maxNewTokens
-         * @return
-         */
-        [[nodiscard("Generated tokens will be freed after this call if not assigned to an lvalue")]]
-        virtual std::expected<size_t, backend_error_t> generate(
-                std::span<const llama_token> input_tokens,
-                std::span<llama_token> generated_tokens,
-                const generation_params_t &generation_params,
-                const sampling_params_t &sampling_params,
-                const std::optional<llama_decode_callback> &callback
-        ) = 0;
-
-        /**
-         *
-         * @param tokens
-         * @param params
-         * @param maxNewTokens
+         * @param generation_params
+         * @param sampling_params
+         * @param callback
          * @return
          */
         [[nodiscard("Generated tokens will be freed after this call if not assigned to an lvalue")]]
@@ -153,6 +136,22 @@
                 const sampling_params_t &sampling_params,
                 const std::optional<llama_decode_callback> &callback = std::nullopt
         );
+
+        /**
+         *
+         * @param tokens
+         * @param generation_params
+         * @param sampling_params
+         * @params callback
+         * @return
+         */
+        [[nodiscard("Generated tokens will be freed after this call if not assigned to an lvalue")]]
+        virtual std::expected<size_t, backend_error_t> stream(
+                std::span<const llama_token> tokens,
+                const generation_params_t &generation_params,
+                const sampling_params_t &sampling_params,
+                const llama_decode_callback &callback
+        ) = 0;
     };
@@ -174,16 +173,11 @@
     public:
         explicit single_worker_backend_t(llama_model *pModel, const std::optional<llama_context_params> &);
 
-        using backend_base_t::generate;
-
-        std::expected<size_t, backend_error_t>
-        generate(
+        std::expected<size_t, backend_error_t> stream(
                 std::span<const llama_token> tokens,
-                std::span<llama_token> out,
                 const generation_params_t &generation_params,
                 const sampling_params_t &sampling_params,
-                const std::optional<llama_decode_callback> &callback
-        ) override;
+                const llama_decode_callback &callback) override;
     };
 
     class multi_worker_backend_t : backend_base_t {
@@ -191,13 +185,11 @@
         llama_context_ptr mContext_;
 
     public:
-        std::expected<size_t, backend_error_t> generate(
-                std::span<const llama_token>,
-                std::span<llama_token>,
+        std::expected<size_t, backend_error_t> stream(
+                std::span<const llama_token> tokens,
                 const generation_params_t &generation_params,
                 const sampling_params_t &sampling_params,
-                const std::optional<llama_decode_callback> &callback
-        ) override;
+                const llama_decode_callback &callback) override;
     };
 }

View File

@@ -28,23 +28,20 @@ namespace huggingface::tgi::backends::llamacpp {
     // Concept identifying types which have a .generate() -> size_t method to do in-place generation
     template<typename T>
-    concept has_emplace_generate = requires(
+    concept has_stream_method = requires(
             T t,
             std::span<const llama_token> input_tokens,
-            std::span<llama_token> generated_tokens,
             const generation_params_t &generation_params,
             const sampling_params_t &sampling_params,
             llama_decode_callback callback
     ) {
         {
-            t.generate(input_tokens, generated_tokens, generation_params, sampling_params, callback)
+            t.stream(input_tokens, generation_params, sampling_params, callback)
         } -> std::same_as<std::expected<size_t, backend_error_t>>;
     };
 
-    static_assert(has_emplace_generate<single_worker_backend_t>,
-                  "single_worker_backend_t doesn't meet concept is_generate_emplace_capable");
-    static_assert(has_emplace_generate<multi_worker_backend_t>,
-                  "multi_worker_backend_t doesn't meet concept is_generate_emplace_capable");
+    static_assert(has_stream_method<single_worker_backend_t>, "single_worker_backend_t doesn't meet concept has_stream_method");
+    static_assert(has_stream_method<multi_worker_backend_t>, "multi_worker_backend_t doesn't meet concept has_stream_method");
 
     class llama_cpp_backend_exception_t : std::exception {
@@ -64,29 +61,25 @@
         size_t stream(
                 rust::Slice<const uint32_t> input_tokens,
-                rust::Slice <uint32_t> generated_tokens,
                 const generation_params_t generation_params,
                 const sampling_params_t &sampling_params,
-                OpaqueStream *stream,
-                rust::Fn<void(OpaqueStream *, uint32_t, float_t, bool, size_t)> callback
+                InferContext *ctx,
+                rust::Fn<void(InferContext *, uint32_t, float_t, bool, size_t)> callback
         ) {
             // Define the visitor lambda function which requires the has_emplace_generate constraint on T
-            auto inner_fw = [=, &sampling_params, &stream, &callback]<has_emplace_generate T>(T &&backend)
+            auto inner_fw = [=, &sampling_params, &ctx, &callback]<has_stream_method T>(T &&backend)
                     -> std::expected<size_t, backend_error_t> {
 
-                auto context_forwarding_callback = [=, &stream](uint32_t new_token_id, float_t logits, bool is_eos, size_t n_generated_tokens){
-                    callback(stream, new_token_id, logits, is_eos, n_generated_tokens);
+                auto context_forwarding_callback = [=, &ctx](uint32_t new_token_id, float_t logits, bool is_eos, size_t n_generated_tokens){
+                    callback(ctx, new_token_id, logits, is_eos, n_generated_tokens);
                 };
 
                 // Ask the compiler to create view over Rust slice transmuting from uint32_t* to int32_t*
                 auto input_tokens_v =
                         std::span(reinterpret_cast<const llama_token *>(input_tokens.data()), input_tokens.size());
-                auto generated_tokens_v =
-                        std::span(reinterpret_cast<llama_token *>(generated_tokens.data()), generated_tokens.size());
 
-                return backend.generate(
+                return backend.stream(
                         input_tokens_v,
-                        generated_tokens_v,
                         generation_params,
                         sampling_params,
                         context_forwarding_callback

View File

@@ -1,7 +1,6 @@
 use crate::ffi::{
     create_single_worker_backend, GenerationParams, LlamaCppBackendImpl, SamplingParams,
 };
-use crate::OpaqueStream;
 use async_trait::async_trait;
 use cxx::UniquePtr;
 use std::path::{Path, PathBuf};
@@ -14,12 +13,13 @@ use text_generation_router::validation::{
 };
 use text_generation_router::{FinishReason, Token};
 use thiserror::Error;
+use tokio::sync::mpsc::error::SendError;
 use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
 use tokio::time::Instant;
 use tokio_stream::wrappers::UnboundedReceiverStream;
 use tracing::{debug, error, info};
 
-type BoxedOpaqueStream = Box<OpaqueStream>;
+type InferResult = Result<InferStreamResponse, InferError>;
 
 unsafe impl Send for LlamaCppBackendImpl {}
@@ -45,14 +45,19 @@ impl From<&ValidStoppingParameters> for GenerationParams {
 }
 
 #[cfg_attr(debug_assertions, derive(Debug))]
-struct InferContext {
-    pub(crate) stream: UnboundedSender<Result<InferStreamResponse, InferError>>,
+struct GenerationContext {
     pub(crate) input_tokens: Arc<Vec<u32>>,
     pub(crate) generated_tokens: Vec<u32>,
     pub(crate) generation_params: GenerationParams,
     pub(crate) sampling_params: SamplingParams,
 }
 
+pub(crate) struct InferContext {
+    pub(crate) start: Instant,
+    pub(crate) stream: UnboundedSender<InferResult>,
+    pub(crate) generation: GenerationContext,
+}
+
 #[derive(Debug, Error)]
 pub enum LlamaCppBackendError {
     #[error("Provided GGUF model path {0} doesn't exist")]
@@ -63,7 +68,7 @@ pub enum LlamaCppBackendError {
 }
 
 pub struct LlamaCppBackend {
-    backlog: Sender<InferContext>,
+    backlog: Sender<(GenerationContext, UnboundedSender<InferResult>)>,
     scheduler_handle: JoinHandle<()>,
 }
@@ -98,81 +103,96 @@ impl LlamaCppBackend {
 }
 
 fn llama_generate_callback(
-    channel: *mut OpaqueStream,
+    ctx: *mut InferContext,
     new_token_id: u32,
     new_token_logit: f32,
-    is_eos: bool,
+    is_final: bool,
     n_generated_tokens: usize,
 ) {
-    let response = InferStreamResponse::Intermediate {
-        token: Token {
-            id: new_token_id,
-            text: "".to_string(),
-            logprob: new_token_logit,
-            special: false,
-        },
-        top_tokens: vec![],
-    };
-    info!("Generated token: {new_token_id} -> logits={new_token_logit}, is_eos={is_eos} ({n_generated_tokens})");
+    info!("Generated token: {new_token_id} -> logits={new_token_logit}, is_final={is_final} ({n_generated_tokens})");
 
-    unsafe {
-        if let Err(ref err) = (*channel).0.send(Ok(response)) {
-            error!(
-                "Failed to send back token to the client: {}",
-                err.to_string()
-            );
-        };
+    // Decode token
+    let token = Token {
+        id: new_token_id,
+        text: "".to_string(),
+        logprob: new_token_logit,
+        special: false,
+    };
+
+    let ctx = unsafe { &mut *ctx };
+
+    // Append the new token to the generated ones
+    ctx.generation.generated_tokens.push(new_token_id);
+
+    // Create the streamed response
+    let response = match is_final {
+        false => InferStreamResponse::Intermediate {
+            token,
+            top_tokens: vec![],
+        },
+        true => {
+            // Decode the whole text
+            let text = String::new();
+
+            // Stream end response
+            InferStreamResponse::End {
+                token,
+                top_tokens: vec![],
+                generated_text: GeneratedText {
+                    text,
+                    generated_tokens: n_generated_tokens as u32,
+                    finish_reason: FinishReason::Length,
+                    seed: Some(ctx.generation.sampling_params.seed),
+                },
+                start: ctx.start,
+                queued: ctx.start,
+            }
+        }
+    };
+
+    // Send back to the client
+    if let Err(ref err) = ctx.stream.send(Ok(response)) {
+        error!("Failed to send back the response to the client, cancelling request");
+        // TODO: cancel the request
     }
 }
 
 unsafe fn scheduler_loop(
     mut backend: UniquePtr<LlamaCppBackendImpl>,
-    backlog: Receiver<InferContext>,
+    backlog: Receiver<(GenerationContext, UnboundedSender<InferResult>)>,
 ) {
+    // This loop will mostly decode single token at every step, so no need to rely on parallelism
+    tokenizers::utils::parallelism::set_parallelism(false);
+
     loop {
-        if let Ok(mut ctx) = backlog.recv() {
+        if let Ok((generation, stream)) = backlog.recv() {
             let start = Instant::now();
-            let stream = BoxedOpaqueStream::new(OpaqueStream(ctx.stream));
-            let stream_ptr = Box::into_raw(stream);
-            let result = backend.pin_mut().stream(
-                &ctx.input_tokens,
-                &mut ctx.generated_tokens,
-                ctx.generation_params,
-                &ctx.sampling_params,
-                stream_ptr,
-                llama_generate_callback,
-            );
+            let generation_params = generation.generation_params; // copy
+            let sampling_params = generation.sampling_params; // copy
+            let input_tokens = Arc::clone(&generation.input_tokens);
 
-            // Make sure we re-keep track of the OpaqueStream box
-            let stream = Box::from_raw(stream_ptr);
-
-            match result {
-                Ok(n_tokens) => {
-                    unsafe {
-                        ctx.generated_tokens.set_len(n_tokens);
-                    }
-
-                    let _ = stream.0.send(Ok(InferStreamResponse::End {
-                        token: Token {
-                            id: ctx.generated_tokens[n_tokens - 1],
-                            text: "".to_string(),
-                            logprob: 0.0,
-                            special: false,
-                        },
-                        top_tokens: vec![],
-                        generated_text: GeneratedText {
-                            text: "".to_string(),
-                            generated_tokens: n_tokens as u32,
-                            finish_reason: FinishReason::Length,
-                            seed: Some(ctx.sampling_params.seed),
-                        },
-                        start,
-                        queued: start,
-                    }));
-                    debug!("Generated {n_tokens} tokens -> {:?}", ctx.generated_tokens);
-                }
-                Err(err) => println!("Error: {err}"),
+            // Creating the whole InferContext and pushing it to the heap
+            {
+                let ctx = Box::new(InferContext {
+                    start,
+                    stream,
+                    generation,
+                });
+                let boxed_ctx = Box::into_raw(ctx);
+
+                if let Err(e) = backend.pin_mut().stream(
+                    &input_tokens,
+                    generation_params,
+                    &sampling_params,
+                    boxed_ctx,
+                    llama_generate_callback,
+                ) {
+                    error!("Error while decoding tokens... {}", e.what());
+                }
+
+                // Make sure we re-keep track of the OpaqueStream box
+                let _ = Box::from_raw(boxed_ctx);
             }
         } else {
            info!("IPC channel is closed, exiting the scheduler loop");
@@ -186,21 +206,20 @@ impl Backend for LlamaCppBackend {
     fn schedule(
         &self,
         request: ValidGenerateRequest,
-    ) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {
+    ) -> Result<UnboundedReceiverStream<InferResult>, InferError> {
         if let Some(input_ids) = request.input_ids {
             let (sx, rx) = unbounded_channel();
             let sampling_params = SamplingParams::from(&request.parameters);
             let generation_params = GenerationParams::from(&request.stopping_parameters);
 
-            let ctx = InferContext {
-                stream: sx,
+            let ctx = GenerationContext {
                 input_tokens: Arc::clone(&input_ids),
                 generated_tokens: Vec::with_capacity(generation_params.max_new_tokens as usize),
                 generation_params,
                 sampling_params,
             };
 
-            match self.backlog.send(ctx) {
+            match self.backlog.send((ctx, sx)) {
                 Ok(_) => Ok(UnboundedReceiverStream::new(rx)),
                 Err(_) => Err(InferError::GenerationError(
                     "Failed to sent the request".to_string(),

View File

@@ -1,6 +1,5 @@
+use crate::backend::InferContext;
 use crate::ffi::SamplingParams;
-use text_generation_router::infer::{InferError, InferStreamResponse};
-use tokio::sync::mpsc::UnboundedSender;
 
 pub mod backend;
@@ -16,8 +15,6 @@ impl Default for SamplingParams {
     }
 }
 
-struct OpaqueStream(UnboundedSender<Result<InferStreamResponse, InferError>>);
-
 #[cxx::bridge(namespace = "huggingface::tgi::backends::llamacpp")]
 mod ffi {
     #[derive(Debug, Copy, Clone)]
@@ -36,7 +33,7 @@
     }
 
     extern "Rust" {
-        type OpaqueStream;
+        type InferContext;
     }
 
     unsafe extern "C++" {
@@ -66,11 +63,10 @@
         unsafe fn stream(
             self: Pin<&mut LlamaCppBackendImpl>,
             tokens: &[u32],
-            generated: &mut [u32],
             generation_params: GenerationParams,
             sampling_params: &SamplingParams,
-            stream: *mut OpaqueStream,
-            callback: unsafe fn(*mut OpaqueStream, u32, f32, bool, usize),
+            stream: *mut InferContext,
+            callback: unsafe fn(*mut InferContext, u32, f32, bool, usize),
         ) -> Result<usize>;
     }
 }