mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-24 16:32:12 +00:00
feat(backend): add early stopping criteria from TGI stream callback
This commit is contained in:
parent
958c72a44a
commit
1473259f84
@ -121,11 +121,12 @@ namespace huggingface::tgi::backends::llamacpp {
|
|||||||
generating = !(has_reach_max_tokens | has_reach_eog);
|
generating = !(has_reach_max_tokens | has_reach_eog);
|
||||||
|
|
||||||
// Bubble up the generated token if a callback is provided
|
// Bubble up the generated token if a callback is provided
|
||||||
std::invoke(std::forward<const llama_decode_callback>(callback_),
|
const auto should_stop = std::invoke(std::forward<const llama_decode_callback>(callback_),
|
||||||
new_token_id,
|
new_token_id,
|
||||||
new_token_logits,
|
new_token_logits,
|
||||||
!generating,
|
!generating,
|
||||||
n_decoded_tokens + 1);
|
n_decoded_tokens + 1);
|
||||||
|
generating ^= should_stop;
|
||||||
|
|
||||||
batch = llama_batch_get_one(&new_token_id, 1);
|
batch = llama_batch_get_one(&new_token_id, 1);
|
||||||
}
|
}
|
||||||
@ -148,11 +149,12 @@ namespace huggingface::tgi::backends::llamacpp {
|
|||||||
// TODO: Should we provide a way to change this value?
|
// TODO: Should we provide a way to change this value?
|
||||||
auto generated = std::vector<llama_token>(2 << 8);
|
auto generated = std::vector<llama_token>(2 << 8);
|
||||||
auto inner_callback = [&](uint32_t new_token_id, float_t new_token_logit, bool is_eos,
|
auto inner_callback = [&](uint32_t new_token_id, float_t new_token_logit, bool is_eos,
|
||||||
size_t num_generated_tokens) {
|
size_t num_generated_tokens) -> bool {
|
||||||
generated.emplace_back(new_token_id);
|
generated.emplace_back(new_token_id);
|
||||||
|
|
||||||
if (callback.has_value())
|
if (callback.has_value())
|
||||||
(*callback)(new_token_id, new_token_logit, is_eos, num_generated_tokens);
|
return (*callback)(new_token_id, new_token_logit, is_eos, num_generated_tokens);
|
||||||
|
return true;
|
||||||
};
|
};
|
||||||
|
|
||||||
auto nTokensGenerated = stream(tokens, generation_params, sampling_params, inner_callback);
|
auto nTokensGenerated = stream(tokens, generation_params, sampling_params, inner_callback);
|
||||||
|
@ -29,8 +29,8 @@ namespace huggingface::tgi::backends::llamacpp {
|
|||||||
static constexpr auto llama_sampler_deleter = [](llama_sampler *pSampler) { llama_sampler_free(pSampler); };
|
static constexpr auto llama_sampler_deleter = [](llama_sampler *pSampler) { llama_sampler_free(pSampler); };
|
||||||
typedef std::unique_ptr<llama_sampler, decltype(llama_sampler_deleter)> llama_sampler_ptr;
|
typedef std::unique_ptr<llama_sampler, decltype(llama_sampler_deleter)> llama_sampler_ptr;
|
||||||
|
|
||||||
typedef std::function<void(llama_token, float_t, bool, size_t)> llama_decode_callback;
|
typedef std::function<bool(llama_token, float_t, bool, size_t)> llama_decode_callback;
|
||||||
static constexpr auto llama_void_callback = [](llama_token, float_t, bool, size_t) {};
|
static constexpr auto llama_void_callback = [](llama_token, float_t, bool, size_t) -> bool { return false; };
|
||||||
|
|
||||||
/**
|
/**
|
||||||
*
|
*
|
||||||
|
@ -64,14 +64,14 @@ namespace huggingface::tgi::backends::llamacpp {
|
|||||||
const generation_params_t generation_params,
|
const generation_params_t generation_params,
|
||||||
const sampling_params_t &sampling_params,
|
const sampling_params_t &sampling_params,
|
||||||
InferContext *ctx,
|
InferContext *ctx,
|
||||||
rust::Fn<void(InferContext *, uint32_t, float_t, bool, size_t)> callback
|
rust::Fn<bool(InferContext *, uint32_t, float_t, bool, size_t)> callback
|
||||||
) {
|
) {
|
||||||
// Define the visitor lambda function which requires the has_emplace_generate constraint on T
|
// Define the visitor lambda function which requires the has_emplace_generate constraint on T
|
||||||
auto inner_fw = [=, &sampling_params, &ctx, &callback]<has_stream_method T>(T &&backend)
|
auto inner_fw = [=, &sampling_params, &ctx, &callback]<has_stream_method T>(T &&backend)
|
||||||
-> std::expected<size_t, backend_error_t> {
|
-> std::expected<size_t, backend_error_t> {
|
||||||
|
|
||||||
auto context_forwarding_callback = [=, &ctx](uint32_t new_token_id, float_t logits, bool is_eos, size_t n_generated_tokens){
|
auto context_forwarding_callback = [=, &ctx](uint32_t new_token_id, float_t logits, bool is_eos, size_t n_generated_tokens) -> bool {
|
||||||
callback(ctx, new_token_id, logits, is_eos, n_generated_tokens);
|
return callback(ctx, new_token_id, logits, is_eos, n_generated_tokens);
|
||||||
};
|
};
|
||||||
|
|
||||||
// Ask the compiler to create view over Rust slice transmuting from uint32_t* to int32_t*
|
// Ask the compiler to create view over Rust slice transmuting from uint32_t* to int32_t*
|
||||||
|
@ -13,11 +13,10 @@ use text_generation_router::validation::{
|
|||||||
};
|
};
|
||||||
use text_generation_router::{FinishReason, Token};
|
use text_generation_router::{FinishReason, Token};
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use tokio::sync::mpsc::error::SendError;
|
|
||||||
use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
|
use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
|
||||||
use tokio::time::Instant;
|
use tokio::time::Instant;
|
||||||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||||
use tracing::{debug, error, info};
|
use tracing::{error, info};
|
||||||
|
|
||||||
type InferResult = Result<InferStreamResponse, InferError>;
|
type InferResult = Result<InferStreamResponse, InferError>;
|
||||||
|
|
||||||
@ -45,7 +44,7 @@ impl From<&ValidStoppingParameters> for GenerationParams {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[cfg_attr(debug_assertions, derive(Debug))]
|
#[cfg_attr(debug_assertions, derive(Debug))]
|
||||||
struct GenerationContext {
|
pub(crate) struct GenerationContext {
|
||||||
pub(crate) input_tokens: Arc<Vec<u32>>,
|
pub(crate) input_tokens: Arc<Vec<u32>>,
|
||||||
pub(crate) generated_tokens: Vec<u32>,
|
pub(crate) generated_tokens: Vec<u32>,
|
||||||
pub(crate) generation_params: GenerationParams,
|
pub(crate) generation_params: GenerationParams,
|
||||||
@ -108,7 +107,7 @@ fn llama_generate_callback(
|
|||||||
new_token_logit: f32,
|
new_token_logit: f32,
|
||||||
is_final: bool,
|
is_final: bool,
|
||||||
n_generated_tokens: usize,
|
n_generated_tokens: usize,
|
||||||
) {
|
) -> bool {
|
||||||
info!("Generated token: {new_token_id} -> logits={new_token_logit}, is_final={is_final} ({n_generated_tokens})");
|
info!("Generated token: {new_token_id} -> logits={new_token_logit}, is_final={is_final} ({n_generated_tokens})");
|
||||||
|
|
||||||
// Decode token
|
// Decode token
|
||||||
@ -151,10 +150,14 @@ fn llama_generate_callback(
|
|||||||
};
|
};
|
||||||
|
|
||||||
// Send back to the client
|
// Send back to the client
|
||||||
if let Err(ref err) = ctx.stream.send(Ok(response)) {
|
if let Err(ref _err) = ctx.stream.send(Ok(response)) {
|
||||||
error!("Failed to send back the response to the client, cancelling request");
|
error!("Failed to send back the response to the client, cancelling request");
|
||||||
// TODO: cancel the request
|
// TODO: cancel the request
|
||||||
|
return true; // should_stop
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// should_stop
|
||||||
|
false
|
||||||
}
|
}
|
||||||
|
|
||||||
unsafe fn scheduler_loop(
|
unsafe fn scheduler_loop(
|
||||||
|
@ -58,7 +58,7 @@ mod ffi {
|
|||||||
generation_params: GenerationParams,
|
generation_params: GenerationParams,
|
||||||
sampling_params: &SamplingParams,
|
sampling_params: &SamplingParams,
|
||||||
stream: *mut InferContext,
|
stream: *mut InferContext,
|
||||||
callback: unsafe fn(*mut InferContext, u32, f32, bool, usize),
|
callback: unsafe fn(*mut InferContext, u32, f32, bool, usize) -> bool,
|
||||||
) -> Result<usize>;
|
) -> Result<usize>;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user