From 1473259f84fb0272b357392a13eaa168d39bc1c4 Mon Sep 17 00:00:00 2001
From: Morgan Funtowicz
Date: Mon, 4 Nov 2024 17:01:22 +0100
Subject: [PATCH] feat(backend): add early stopping criteria from TGI stream callback

---
 backends/llamacpp/csrc/backend.cpp | 16 +++++++++-------
 backends/llamacpp/csrc/backend.hpp |  4 ++--
 backends/llamacpp/csrc/ffi.hpp     |  6 +++---
 backends/llamacpp/src/backend.rs   | 13 ++++++++-----
 backends/llamacpp/src/lib.rs       |  2 +-
 5 files changed, 23 insertions(+), 18 deletions(-)

diff --git a/backends/llamacpp/csrc/backend.cpp b/backends/llamacpp/csrc/backend.cpp
index 65898dfe..f6956381 100644
--- a/backends/llamacpp/csrc/backend.cpp
+++ b/backends/llamacpp/csrc/backend.cpp
@@ -121,11 +121,12 @@ namespace huggingface::tgi::backends::llamacpp {
             generating = !(has_reach_max_tokens | has_reach_eog);
 
             // Bubble up the generated token if a callback is provided
-            std::invoke(std::forward(callback_),
-                        new_token_id,
-                        new_token_logits,
-                        !generating,
-                        n_decoded_tokens + 1);
+            const auto should_stop = std::invoke(std::forward(callback_),
+                                                 new_token_id,
+                                                 new_token_logits,
+                                                 !generating,
+                                                 n_decoded_tokens + 1);
+            generating ^= should_stop;
 
             batch = llama_batch_get_one(&new_token_id, 1);
         }
@@ -148,11 +149,12 @@
         // TODO: Should we provide a way to change this value?
         auto generated = std::vector(2 << 8);
         auto inner_callback = [&](uint32_t new_token_id, float_t new_token_logit, bool is_eos,
-                                  size_t num_generated_tokens) {
+                                  size_t num_generated_tokens) -> bool {
             generated.emplace_back(new_token_id);
 
             if (callback.has_value())
-                (*callback)(new_token_id, new_token_logit, is_eos, num_generated_tokens);
+                return (*callback)(new_token_id, new_token_logit, is_eos, num_generated_tokens);
+            return true;
         };
 
         auto nTokensGenerated = stream(tokens, generation_params, sampling_params, inner_callback);
diff --git a/backends/llamacpp/csrc/backend.hpp b/backends/llamacpp/csrc/backend.hpp
index 1fef7fb8..bf9df5cc 100644
--- a/backends/llamacpp/csrc/backend.hpp
+++ b/backends/llamacpp/csrc/backend.hpp
@@ -29,8 +29,8 @@
     static constexpr auto llama_sampler_deleter = [](llama_sampler *pSampler) { llama_sampler_free(pSampler); };
     typedef std::unique_ptr<llama_sampler, decltype(llama_sampler_deleter)> llama_sampler_ptr;
 
-    typedef std::function<void(llama_token, float_t, bool, size_t)> llama_decode_callback;
-    static constexpr auto llama_void_callback = [](llama_token, float_t, bool, size_t) {};
+    typedef std::function<bool(llama_token, float_t, bool, size_t)> llama_decode_callback;
+    static constexpr auto llama_void_callback = [](llama_token, float_t, bool, size_t) -> bool { return false; };
 
     /**
      *
diff --git a/backends/llamacpp/csrc/ffi.hpp b/backends/llamacpp/csrc/ffi.hpp
index 3ae392f6..f33a2f1a 100644
--- a/backends/llamacpp/csrc/ffi.hpp
+++ b/backends/llamacpp/csrc/ffi.hpp
@@ -64,14 +64,14 @@
             const generation_params_t generation_params,
             const sampling_params_t &sampling_params,
             InferContext *ctx,
-            rust::Fn<void(InferContext *, uint32_t, float_t, bool, size_t)> callback
+            rust::Fn<bool(InferContext *, uint32_t, float_t, bool, size_t)> callback
     ) {
 
         // Define the visitor lambda function which requires the has_emplace_generate constraint on T
         auto inner_fw = [=, &sampling_params, &ctx, &callback](T &&backend) -> std::expected {
 
-            auto context_forwarding_callback = [=, &ctx](uint32_t new_token_id, float_t logits, bool is_eos, size_t n_generated_tokens){
-                callback(ctx, new_token_id, logits, is_eos, n_generated_tokens);
+            auto context_forwarding_callback = [=, &ctx](uint32_t new_token_id, float_t logits, bool is_eos, size_t n_generated_tokens) -> bool {
+                return callback(ctx, new_token_id, logits, is_eos, n_generated_tokens);
             };
 
             // Ask the compiler to create view over Rust slice transmuting from uint32_t* to int32_t*
diff --git a/backends/llamacpp/src/backend.rs b/backends/llamacpp/src/backend.rs
index 06e8d43e..531a07dc 100644
--- a/backends/llamacpp/src/backend.rs
+++ b/backends/llamacpp/src/backend.rs
@@ -13,11 +13,10 @@ use text_generation_router::validation::{
 };
 use text_generation_router::{FinishReason, Token};
 use thiserror::Error;
-use tokio::sync::mpsc::error::SendError;
 use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
 use tokio::time::Instant;
 use tokio_stream::wrappers::UnboundedReceiverStream;
-use tracing::{debug, error, info};
+use tracing::{error, info};
 
 type InferResult = Result<InferStreamResponse, InferError>;
 
@@ -45,7 +44,7 @@ impl From<&ValidStoppingParameters> for GenerationParams {
 }
 
 #[cfg_attr(debug_assertions, derive(Debug))]
-struct GenerationContext {
+pub(crate) struct GenerationContext {
     pub(crate) input_tokens: Arc<Vec<u32>>,
     pub(crate) generated_tokens: Vec<u32>,
     pub(crate) generation_params: GenerationParams,
@@ -108,7 +107,7 @@ fn llama_generate_callback(
     new_token_logit: f32,
     is_final: bool,
    n_generated_tokens: usize,
-) {
+) -> bool {
     info!("Generated token: {new_token_id} -> logits={new_token_logit}, is_final={is_final} ({n_generated_tokens})");
 
     // Decode token
@@ -151,10 +150,14 @@
     };
 
     // Send back to the client
-    if let Err(ref err) = ctx.stream.send(Ok(response)) {
+    if let Err(ref _err) = ctx.stream.send(Ok(response)) {
         error!("Failed to send back the response to the client, cancelling request");
         // TODO: cancel the request
+        return true; // should_stop
     }
+
+    // should_stop
+    false
 }
 
 unsafe fn scheduler_loop(
diff --git a/backends/llamacpp/src/lib.rs b/backends/llamacpp/src/lib.rs
index 006c7387..abcdd1fa 100644
--- a/backends/llamacpp/src/lib.rs
+++ b/backends/llamacpp/src/lib.rs
@@ -58,7 +58,7 @@ mod ffi {
            generation_params: GenerationParams,
            sampling_params: &SamplingParams,
            stream: *mut InferContext,
-           callback: unsafe fn(*mut InferContext, u32, f32, bool, usize),
+           callback: unsafe fn(*mut InferContext, u32, f32, bool, usize) -> bool,
        ) -> Result<usize>;
    }
 }
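The control-flow change carried by this patch: the per-token decode callback now returns a bool ("should stop"), and the C++ generation loop folds that value into its generating flag, so the Rust side can cancel decoding as soon as the client stream goes away. The standalone C++ sketch below only illustrates that pattern under stated assumptions; token_t, decode_one, run_generation and the token values are hypothetical stand-ins, not part of the TGI llama.cpp backend.

// Standalone sketch of the early-stopping callback pattern (illustrative only;
// token_t, decode_one and run_generation are hypothetical, not TGI code).
#include <cstddef>
#include <cstdint>
#include <functional>
#include <iostream>

using token_t = std::uint32_t;

// Mirrors the patched callback shape: (token, logit, is_final, n_generated) -> should_stop.
using decode_callback = std::function<bool(token_t, float, bool, std::size_t)>;

// Stand-in for "decode one step and sample a token".
static token_t decode_one(std::size_t step) { return static_cast<token_t>(1000 + step); }

// Decode until the token budget is exhausted or the callback asks to stop.
std::size_t run_generation(std::size_t max_new_tokens, const decode_callback &callback) {
    std::size_t n_decoded = 0;
    bool generating = true;
    while (generating) {
        const token_t new_token = decode_one(n_decoded);
        const bool reached_limit = (n_decoded + 1) >= max_new_tokens;
        generating = !reached_limit;

        // Bubble the token up; the callback's return value can now end the loop early.
        const bool should_stop = callback(new_token, /*logit=*/0.0f, !generating, n_decoded + 1);
        generating = generating && !should_stop;

        ++n_decoded;
    }
    return n_decoded;
}

int main() {
    // Ask to stop after 3 tokens even though the budget is 16, much as a
    // stream callback would after a failed send to the client.
    const auto n = run_generation(16, [](token_t id, float /*logit*/, bool is_final, std::size_t n_generated) {
        std::cout << "token " << id << " (#" << n_generated << ", final=" << is_final << ")\n";
        return n_generated >= 3; // should_stop
    });
    std::cout << "generated " << n << " tokens\n";
    return 0;
}

In this sketch the stop request is folded in with "generating && !should_stop" rather than the patch's "generating ^= should_stop", so a stop request can never flip a finished loop back on; in the normal path the observable behaviour is the same.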