From 5b7a951389216a58cc603c28b1c3ea8e87930bca Mon Sep 17 00:00:00 2001
From: Morgan Funtowicz
Date: Mon, 4 Nov 2024 16:17:43 +0100
Subject: [PATCH] feat(backend): refactor the callback to handle intermediate
 and end inference message

---
 backends/llamacpp/csrc/backend.cpp |  35 +++----
 backends/llamacpp/csrc/backend.hpp |  56 +++++------
 backends/llamacpp/csrc/ffi.hpp     |  27 ++---
 backends/llamacpp/src/backend.rs   | 153 ++++++++++++++++-------------
 backends/llamacpp/src/lib.rs       |  12 +--
 5 files changed, 142 insertions(+), 141 deletions(-)

diff --git a/backends/llamacpp/csrc/backend.cpp b/backends/llamacpp/csrc/backend.cpp
index 79c09a26..65898dfe 100644
--- a/backends/llamacpp/csrc/backend.cpp
+++ b/backends/llamacpp/csrc/backend.cpp
@@ -114,9 +114,6 @@ namespace huggingface::tgi::backends::llamacpp {
                 auto is_eog = llama_token_is_eog(mModel_.get(), new_token_id);
                 auto new_token_logits = 0.0f; // TODO: return logit
 
-                // Push the token to the generated vector on Rust side
-                generation_context.generated_tokens[n_decoded_tokens] = new_token_id;
-
                 // Handle termination cases
                 const auto has_reach_max_tokens = n_decoded_tokens >= max_new_tokens - 1;
                 const auto has_reach_eog = !generation_context.generation_params.ignore_eos_token & is_eog;
@@ -150,10 +147,15 @@ namespace huggingface::tgi::backends::llamacpp {
         ) {
             // TODO: Should we provide a way to change this value?
             auto generated = std::vector<llama_token>(2 << 8);
+            auto inner_callback = [&](uint32_t new_token_id, float_t new_token_logit, bool is_eos,
+                                      size_t num_generated_tokens) {
+                generated.emplace_back(new_token_id);
 
-            auto nTokensGenerated = generate(tokens, generated, generation_params, sampling_params, callback);
-            if (nTokensGenerated.has_value())
-                generated.resize(*nTokensGenerated);
+                if (callback.has_value())
+                    (*callback)(new_token_id, new_token_logit, is_eos, num_generated_tokens);
+            };
+
+            auto nTokensGenerated = stream(tokens, generation_params, sampling_params, inner_callback);
             return generated;
         }
 
@@ -168,25 +170,24 @@ namespace huggingface::tgi::backends::llamacpp {
             llama_numa_init(ggml_numa_strategy::GGML_NUMA_STRATEGY_NUMACTL);
         }
 
-        std::expected<size_t, backend_error_t>
-        single_worker_backend_t::generate(
+        std::expected<size_t, backend_error_t>
+        single_worker_backend_t::stream(
                 std::span<const llama_token> tokens,
-                std::span<llama_token> out,
                 const generation_params_t &generation_params,
                 const sampling_params_t &sampling_params,
-                const std::optional<llama_decode_callback> &callback
+                const llama_decode_callback &callback
         ) {
-            return mWorker_.generate(mContext_.get(), {generation_params, sampling_params, tokens, out}, callback);
+            return mWorker_.generate(mContext_.get(), {generation_params, sampling_params, tokens}, callback);
        }
 
         std::expected<size_t, backend_error_t>
-        multi_worker_backend_t::generate(
-                std::span<const llama_token>,
-                std::span<llama_token>,
+        multi_worker_backend_t::stream(
+                std::span<const llama_token> tokens,
                 const generation_params_t &generation_params,
                 const sampling_params_t &sampling_params,
-                const std::optional<llama_decode_callback> &callback) {
-            SPDLOG_ERROR("Not implemented yet");
-            return 0uz;
+                const llama_decode_callback &callback
+        ) {
+            SPDLOG_WARN("Not implemented for multi_worker_t");
+            return 0;
         }
     }
\ No newline at end of file
diff --git a/backends/llamacpp/csrc/backend.hpp b/backends/llamacpp/csrc/backend.hpp
index ebae7fb0..1fef7fb8 100644
--- a/backends/llamacpp/csrc/backend.hpp
+++ b/backends/llamacpp/csrc/backend.hpp
@@ -69,7 +69,6 @@ namespace huggingface::tgi::backends::llamacpp {
         generation_params_t generation_params;
         sampling_params_t sampling_params;
         std::span<const llama_token> input_tokens;
-        std::span<llama_token> generated_tokens;
     };
 
     /**
@@ -125,25 +124,9 @@ namespace huggingface::tgi::backends::llamacpp {
         /**
          *
          * @param tokens
-         * @params out
-         * @param params
-         * @param maxNewTokens
-         * @return
-         */
-        [[nodiscard("Generated tokens will be freed after this call if not assigned to an lvalue")]]
-        virtual std::expected<size_t, backend_error_t> generate(
-                std::span<const llama_token> input_tokens,
-                std::span<llama_token> generated_tokens,
-                const generation_params_t &generation_params,
-                const sampling_params_t &sampling_params,
-                const std::optional<llama_decode_callback> &callback
-        ) = 0;
-
-        /**
-         *
-         * @param tokens
-         * @param params
-         * @param maxNewTokens
+         * @param generation_params
+         * @param sampling_params
+         * @param callback
          * @return
          */
         [[nodiscard("Generated tokens will be freed after this call if not assigned to an lvalue")]]
@@ -153,6 +136,22 @@ namespace huggingface::tgi::backends::llamacpp {
                 const sampling_params_t &sampling_params,
                 const std::optional<llama_decode_callback> &callback = std::nullopt
         );
+
+        /**
+         *
+         * @param tokens
+         * @param generation_params
+         * @param sampling_params
+         * @param callback
+         * @return
+         */
+        [[nodiscard("Generated tokens will be freed after this call if not assigned to an lvalue")]]
+        virtual std::expected<size_t, backend_error_t> stream(
+                std::span<const llama_token> tokens,
+                const generation_params_t &generation_params,
+                const sampling_params_t &sampling_params,
+                const llama_decode_callback &callback
+        ) = 0;
     };
 
 
@@ -174,16 +173,11 @@ namespace huggingface::tgi::backends::llamacpp {
     public:
         explicit single_worker_backend_t(llama_model *pModel, const std::optional &);
 
-        using backend_base_t::generate;
-
-        std::expected<size_t, backend_error_t>
-        generate(
+        std::expected<size_t, backend_error_t> stream(
                 std::span<const llama_token> tokens,
-                std::span<llama_token> out,
                 const generation_params_t &generation_params,
                 const sampling_params_t &sampling_params,
-                const std::optional<llama_decode_callback> &callback
-        ) override;
+                const llama_decode_callback &callback) override;
     };
 
     class multi_worker_backend_t : backend_base_t {
@@ -191,13 +185,11 @@ namespace huggingface::tgi::backends::llamacpp {
         llama_context_ptr mContext_;
 
     public:
-        std::expected<size_t, backend_error_t> generate(
-                std::span<const llama_token>,
-                std::span<llama_token>,
+        std::expected<size_t, backend_error_t> stream(
+                std::span<const llama_token> tokens,
                 const generation_params_t &generation_params,
                 const sampling_params_t &sampling_params,
-                const std::optional<llama_decode_callback> &callback
-        ) override;
+                const llama_decode_callback &callback) override;
     };
 }
diff --git a/backends/llamacpp/csrc/ffi.hpp b/backends/llamacpp/csrc/ffi.hpp
index df924cb7..3ae392f6 100644
--- a/backends/llamacpp/csrc/ffi.hpp
+++ b/backends/llamacpp/csrc/ffi.hpp
@@ -28,23 +28,20 @@ namespace huggingface::tgi::backends::llamacpp {
 
     // Concept identifying types which have a .generate() -> size_t method to do in-place generation
     template<typename T>
-    concept has_emplace_generate = requires(
+    concept has_stream_method = requires(
             T t,
             std::span<const llama_token> input_tokens,
-            std::span<llama_token> generated_tokens,
            const generation_params_t &generation_params,
             const sampling_params_t &sampling_params,
             llama_decode_callback callback
     ) {
         {
-            t.generate(input_tokens, generated_tokens, generation_params, sampling_params, callback)
+            t.stream(input_tokens, generation_params, sampling_params, callback)
         } -> std::same_as<std::expected<size_t, backend_error_t>>;
     };
 
-    static_assert(has_emplace_generate<single_worker_backend_t>,
-                  "single_worker_backend_t doesn't meet concept is_generate_emplace_capable");
-    static_assert(has_emplace_generate<multi_worker_backend_t>,
-                  "multi_worker_backend_t doesn't meet concept is_generate_emplace_capable");
+    static_assert(has_stream_method<single_worker_backend_t>, "single_worker_backend_t doesn't meet concept has_stream_method");
+    static_assert(has_stream_method<multi_worker_backend_t>, "multi_worker_backend_t doesn't meet concept has_stream_method");
 
     class llama_cpp_backend_exception_t : std::exception {
 
@@ -64,29 +61,25 @@ namespace huggingface::tgi::backends::llamacpp {
         size_t stream(
                 rust::Slice<const uint32_t> input_tokens,
-                rust::Slice<uint32_t> generated_tokens,
                 const generation_params_t generation_params,
                 const sampling_params_t &sampling_params,
-                OpaqueStream *stream,
-                rust::Fn<void(OpaqueStream *, uint32_t, float_t, bool, size_t)> callback
+                InferContext *ctx,
+                rust::Fn<void(InferContext *, uint32_t, float_t, bool, size_t)> callback
         ) {
             // Define the visitor lambda function which requires the has_emplace_generate constraint on T
-            auto inner_fw = [=, &sampling_params, &stream, &callback](T &&backend)
+            auto inner_fw = [=, &sampling_params, &ctx, &callback](T &&backend)
                     -> std::expected<size_t, backend_error_t> {
-                auto context_forwarding_callback = [=, &stream](uint32_t new_token_id, float_t logits, bool is_eos, size_t n_generated_tokens){
-                    callback(stream, new_token_id, logits, is_eos, n_generated_tokens);
+                auto context_forwarding_callback = [=, &ctx](uint32_t new_token_id, float_t logits, bool is_eos, size_t n_generated_tokens){
+                    callback(ctx, new_token_id, logits, is_eos, n_generated_tokens);
                 };
 
                 // Ask the compiler to create view over Rust slice transmuting from uint32_t* to int32_t*
                 auto input_tokens_v =
                         std::span(reinterpret_cast<const int32_t *>(input_tokens.data()), input_tokens.size());
-                auto generated_tokens_v =
-                        std::span(reinterpret_cast<int32_t *>(generated_tokens.data()), generated_tokens.size());
 
-                return backend.generate(
+                return backend.stream(
                         input_tokens_v,
-                        generated_tokens_v,
                         generation_params,
                         sampling_params,
                         context_forwarding_callback
diff --git a/backends/llamacpp/src/backend.rs b/backends/llamacpp/src/backend.rs
index c3fff697..06e8d43e 100644
--- a/backends/llamacpp/src/backend.rs
+++ b/backends/llamacpp/src/backend.rs
@@ -1,7 +1,6 @@
 use crate::ffi::{
     create_single_worker_backend, GenerationParams, LlamaCppBackendImpl, SamplingParams,
 };
-use crate::OpaqueStream;
 use async_trait::async_trait;
 use cxx::UniquePtr;
 use std::path::{Path, PathBuf};
@@ -14,12 +13,13 @@ use text_generation_router::validation::{
 };
 use text_generation_router::{FinishReason, Token};
 use thiserror::Error;
+use tokio::sync::mpsc::error::SendError;
 use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
 use tokio::time::Instant;
 use tokio_stream::wrappers::UnboundedReceiverStream;
 use tracing::{debug, error, info};
 
-type BoxedOpaqueStream = Box<OpaqueStream>;
+type InferResult = Result<InferStreamResponse, InferError>;
 
 unsafe impl Send for LlamaCppBackendImpl {}
 
@@ -45,14 +45,19 @@ impl From<&ValidStoppingParameters> for GenerationParams {
 }
 
 #[cfg_attr(debug_assertions, derive(Debug))]
-struct InferContext {
-    pub(crate) stream: UnboundedSender<Result<InferStreamResponse, InferError>>,
+struct GenerationContext {
     pub(crate) input_tokens: Arc<Vec<u32>>,
     pub(crate) generated_tokens: Vec<u32>,
     pub(crate) generation_params: GenerationParams,
     pub(crate) sampling_params: SamplingParams,
 }
 
+pub(crate) struct InferContext {
+    pub(crate) start: Instant,
+    pub(crate) stream: UnboundedSender<InferResult>,
+    pub(crate) generation: GenerationContext,
+}
+
 #[derive(Debug, Error)]
 pub enum LlamaCppBackendError {
     #[error("Provided GGUF model path {0} doesn't exist")]
@@ -63,7 +68,7 @@ pub enum LlamaCppBackendError {
 }
 
 pub struct LlamaCppBackend {
-    backlog: Sender<InferContext>,
+    backlog: Sender<(GenerationContext, UnboundedSender<InferResult>)>,
     scheduler_handle: JoinHandle<()>,
 }
 
@@ -98,81 +103,96 @@ impl LlamaCppBackend {
 }
 
 fn llama_generate_callback(
-    channel: *mut OpaqueStream,
+    ctx: *mut InferContext,
     new_token_id: u32,
     new_token_logit: f32,
-    is_eos: bool,
+    is_final: bool,
     n_generated_tokens: usize,
 ) {
-    let response = InferStreamResponse::Intermediate {
-        token: Token {
-            id: new_token_id,
-            text: "".to_string(),
-            logprob: new_token_logit,
-            special: false,
-        },
-        top_tokens: vec![],
-    };
-    info!("Generated token: {new_token_id} -> logits={new_token_logit}, is_eos={is_eos} ({n_generated_tokens})");
+    info!("Generated token: {new_token_id} -> logits={new_token_logit}, is_final={is_final} ({n_generated_tokens})");
 
-    unsafe {
-        if let Err(ref err) = (*channel).0.send(Ok(response)) {
-            error!(
-                "Failed to send back token to the client: {}",
-                err.to_string()
-            );
-        };
+    // Decode token
+    let token = Token {
+        id: new_token_id,
+        text: "".to_string(),
+        logprob: new_token_logit,
+        special: false,
+    };
+
+    let ctx = unsafe { &mut *ctx };
+
+    // Append the new token to the generated ones
+    ctx.generation.generated_tokens.push(new_token_id);
+
+    // Create the streamed response
+    let response = match is_final {
+        false => InferStreamResponse::Intermediate {
+            token,
+            top_tokens: vec![],
+        },
+        true => {
+            // Decode the whole text
+            let text = String::new();
+
+            // Stream end response
+            InferStreamResponse::End {
+                token,
+                top_tokens: vec![],
+                generated_text: GeneratedText {
+                    text,
+                    generated_tokens: n_generated_tokens as u32,
+                    finish_reason: FinishReason::Length,
+                    seed: Some(ctx.generation.sampling_params.seed),
+                },
+                start: ctx.start,
+                queued: ctx.start,
+            }
+        }
+    };
+
+    // Send back to the client
+    if let Err(ref err) = ctx.stream.send(Ok(response)) {
+        error!("Failed to send back the response to the client, cancelling request");
+        // TODO: cancel the request
     }
 }
 
 unsafe fn scheduler_loop(
     mut backend: UniquePtr<LlamaCppBackendImpl>,
-    backlog: Receiver<InferContext>,
+    backlog: Receiver<(GenerationContext, UnboundedSender<InferResult>)>,
 ) {
+    // This loop will mostly decode a single token at every step, so there is no need to rely on parallelism
+    tokenizers::utils::parallelism::set_parallelism(false);
+
     loop {
-        if let Ok(mut ctx) = backlog.recv() {
+        if let Ok((generation, stream)) = backlog.recv() {
             let start = Instant::now();
-            let stream = BoxedOpaqueStream::new(OpaqueStream(ctx.stream));
-            let stream_ptr = Box::into_raw(stream);
-            let result = backend.pin_mut().stream(
-                &ctx.input_tokens,
-                &mut ctx.generated_tokens,
-                ctx.generation_params,
-                &ctx.sampling_params,
-                stream_ptr,
-                llama_generate_callback,
-            );
+            let generation_params = generation.generation_params; // copy
+            let sampling_params = generation.sampling_params; // copy
+            let input_tokens = Arc::clone(&generation.input_tokens);
 
-            // Make sure we re-keep track of the OpaqueStream box
-            let stream = Box::from_raw(stream_ptr);
+            // Create the whole InferContext and push it to the heap
+            {
+                let ctx = Box::new(InferContext {
+                    start,
+                    stream,
+                    generation,
+                });
 
-            match result {
-                Ok(n_tokens) => {
-                    unsafe {
-                        ctx.generated_tokens.set_len(n_tokens);
-                    }
+                let boxed_ctx = Box::into_raw(ctx);
 
-                    let _ = stream.0.send(Ok(InferStreamResponse::End {
-                        token: Token {
-                            id: ctx.generated_tokens[n_tokens - 1],
-                            text: "".to_string(),
-                            logprob: 0.0,
-                            special: false,
-                        },
-                        top_tokens: vec![],
-                        generated_text: GeneratedText {
-                            text: "".to_string(),
-                            generated_tokens: n_tokens as u32,
-                            finish_reason: FinishReason::Length,
-                            seed: Some(ctx.sampling_params.seed),
-                        },
-                        start,
-                        queued: start,
-                    }));
-
-                    debug!("Generated {n_tokens} tokens -> {:?}", ctx.generated_tokens);
+                if let Err(e) = backend.pin_mut().stream(
+                    &input_tokens,
+                    generation_params,
+                    &sampling_params,
+                    boxed_ctx,
+                    llama_generate_callback,
+                ) {
+                    error!("Error while decoding tokens... {}", e.what());
{}", e.what()); } - Err(err) => println!("Error: {err}"), + + // Make sure we re-keep track of the OpaqueStream box + let _ = Box::from_raw(boxed_ctx); } } else { info!("IPC channel is closed, exiting the scheduler loop"); @@ -186,21 +206,20 @@ impl Backend for LlamaCppBackend { fn schedule( &self, request: ValidGenerateRequest, - ) -> Result>, InferError> { + ) -> Result, InferError> { if let Some(input_ids) = request.input_ids { let (sx, rx) = unbounded_channel(); let sampling_params = SamplingParams::from(&request.parameters); let generation_params = GenerationParams::from(&request.stopping_parameters); - let ctx = InferContext { - stream: sx, + let ctx = GenerationContext { input_tokens: Arc::clone(&input_ids), generated_tokens: Vec::with_capacity(generation_params.max_new_tokens as usize), generation_params, sampling_params, }; - match self.backlog.send(ctx) { + match self.backlog.send((ctx, sx)) { Ok(_) => Ok(UnboundedReceiverStream::new(rx)), Err(_) => Err(InferError::GenerationError( "Failed to sent the request".to_string(), diff --git a/backends/llamacpp/src/lib.rs b/backends/llamacpp/src/lib.rs index 277f77cb..01f2054d 100644 --- a/backends/llamacpp/src/lib.rs +++ b/backends/llamacpp/src/lib.rs @@ -1,6 +1,5 @@ +use crate::backend::InferContext; use crate::ffi::SamplingParams; -use text_generation_router::infer::{InferError, InferStreamResponse}; -use tokio::sync::mpsc::UnboundedSender; pub mod backend; @@ -16,8 +15,6 @@ impl Default for SamplingParams { } } -struct OpaqueStream(UnboundedSender>); - #[cxx::bridge(namespace = "huggingface::tgi::backends::llamacpp")] mod ffi { #[derive(Debug, Copy, Clone)] @@ -36,7 +33,7 @@ mod ffi { } extern "Rust" { - type OpaqueStream; + type InferContext; } unsafe extern "C++" { @@ -66,11 +63,10 @@ mod ffi { unsafe fn stream( self: Pin<&mut LlamaCppBackendImpl>, tokens: &[u32], - generated: &mut [u32], generation_params: GenerationParams, sampling_params: &SamplingParams, - stream: *mut OpaqueStream, - callback: unsafe fn(*mut OpaqueStream, u32, f32, bool, usize), + stream: *mut InferContext, + callback: unsafe fn(*mut InferContext, u32, f32, bool, usize), ) -> Result; } }