feat(backend): avoid dropping the boxed stream at the end of the callback

Morgan Funtowicz 2024-11-03 00:36:32 +01:00
parent 612f2f939f
commit b50dcddbb8
3 changed files with 84 additions and 40 deletions
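In sum: the per-request sender is wrapped in `OpaqueStream`, boxed, leaked to a raw pointer with `Box::into_raw` for the duration of the blocking FFI call, and reclaimed with `Box::from_raw` only after generation finishes, so it is dropped exactly once on the Rust side rather than at the end of the C++ callback. Below is a minimal, self-contained sketch of that round-trip; all names are stand-ins (std channels instead of tokio, a fake FFI entry point), not the real TGI types.

use std::sync::mpsc::{channel, Sender};

// Stand-in for the real OpaqueStream(UnboundedSender<...>) newtype.
struct OpaqueStream(Sender<u32>);

// Stand-in for the cxx-generated `stream` entry point: the C++ side calls
// the callback once per generated token, handing the context pointer back.
fn fake_ffi_stream(ctx: *mut OpaqueStream, cb: unsafe fn(*mut OpaqueStream, u32)) {
    for token in [1u32, 2, 3] {
        unsafe { cb(ctx, token) };
    }
}

// The callback only borrows through the raw pointer; it must not take
// ownership back, or the sender would be dropped after the first token.
unsafe fn on_token(ctx: *mut OpaqueStream, token: u32) {
    let _ = unsafe { (*ctx).0.send(token) };
}

fn main() {
    let (tx, rx) = channel();

    // Leak the box so the pointer stays valid across the whole FFI call.
    let stream_ptr = Box::into_raw(Box::new(OpaqueStream(tx)));
    fake_ffi_stream(stream_ptr, on_token);

    // Reclaim ownership exactly once, after generation has finished.
    drop(unsafe { Box::from_raw(stream_ptr) });

    assert_eq!(rx.iter().collect::<Vec<_>>(), vec![1, 2, 3]);
}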

backends/llamacpp/csrc/ffi.hpp

@@ -21,6 +21,7 @@ namespace huggingface::tgi::backends::llamacpp {
 #include "backends/llamacpp/src/lib.rs.h"
+#include "rust/cxx.h"

 namespace huggingface::tgi::backends::llamacpp {
@@ -61,17 +62,22 @@ namespace huggingface::tgi::backends::llamacpp {
         explicit llama_cpp_backend_impl_t(multi_worker_backend_t &&backend) : mInner_(std::move(backend)) {}

-        size_t generate(
+        size_t stream(
                 rust::Slice<const uint32_t> input_tokens,
                 rust::Slice <uint32_t> generated_tokens,
-                const generation_params_t &generation_params,
+                const generation_params_t generation_params,
                 const sampling_params_t &sampling_params,
-                rust::Fn<void(uint32_t, float_t, bool)> callback
+                OpaqueStream *stream,
+                rust::Fn<void(OpaqueStream *, uint32_t, float_t, bool)> callback
        ) {
            // Define the visitor lambda function which requires the has_emplace_generate constraint on T
-            static auto inner_fw = [=, &generation_params, &sampling_params]<has_emplace_generate T>(T &&backend)
+            static auto inner_fw = [=, &sampling_params, &stream, &callback]<has_emplace_generate T>(T &&backend)
                    -> std::expected<size_t, backend_error_t> {
+                auto context_forwarding_callback = [=, &stream](uint32_t new_token_id, float_t logits, bool is_eos){
+                    callback(stream, new_token_id, logits, is_eos);
+                };
+
                // Ask the compiler to create view over Rust slice transmuting from uint32_t* to int32_t*
                auto input_tokens_v =
                        std::span(reinterpret_cast<const llama_token *>(input_tokens.data()), input_tokens.size());
@@ -79,7 +85,12 @@
                        std::span(reinterpret_cast<llama_token *>(generated_tokens.data()), generated_tokens.size());

                return backend.generate(
-                        input_tokens_v, generated_tokens_v, generation_params, sampling_params, callback);
+                        input_tokens_v,
+                        generated_tokens_v,
+                        generation_params,
+                        sampling_params,
+                        context_forwarding_callback
+                );
            };

            if (const auto result = std::visit(inner_fw, mInner_); result.has_value()) {
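`context_forwarding_callback` is a small adapter: it closes over the `(stream, callback)` pair and re-exposes the original `(token, logit, is_eos)` arity, so the inner `backend.generate(...)` needs no knowledge of the context pointer. A rough Rust analogue of that adapter, using hypothetical stand-in types:

// Hypothetical stand-ins; the real context wraps an UnboundedSender.
struct OpaqueStream(Vec<u32>);

// Bundle a (context pointer, function pointer) pair back into a plain
// Fn(u32, f32, bool), mirroring what context_forwarding_callback does.
fn adapt(
    stream: *mut OpaqueStream,
    callback: unsafe fn(*mut OpaqueStream, u32, f32, bool),
) -> impl Fn(u32, f32, bool) {
    move |new_token_id, logit, is_eos| unsafe { callback(stream, new_token_id, logit, is_eos) }
}

fn main() {
    unsafe fn record(ctx: *mut OpaqueStream, id: u32, _logit: f32, _is_eos: bool) {
        unsafe { (*ctx).0.push(id) };
    }

    let mut ctx = OpaqueStream(Vec::new());
    let cb = adapt(&mut ctx, record);
    cb(42, 0.5, false); // downstream code sees the original arity
    drop(cb);
    assert_eq!(ctx.0, vec![42]);
}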

backends/llamacpp/src/backend.rs

@@ -1,23 +1,27 @@
 use crate::ffi::{
     create_single_worker_backend, GenerationParams, LlamaCppBackendImpl, SamplingParams,
 };
+use crate::OpaqueStream;
 use async_trait::async_trait;
 use cxx::UniquePtr;
 use std::path::{Path, PathBuf};
 use std::sync::mpsc::{channel, Receiver, SendError, Sender};
 use std::sync::Arc;
 use std::thread::{spawn, JoinHandle};
-use text_generation_router::infer::{Backend, InferError, InferStreamResponse};
+use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
 use text_generation_router::validation::{
     ValidGenerateRequest, ValidParameters, ValidStoppingParameters,
 };
-use text_generation_router::Token;
+use text_generation_router::{FinishReason, Token};
 use thiserror::Error;
 use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
 use tokio::sync::TryAcquireError;
+use tokio::time::Instant;
 use tokio_stream::wrappers::UnboundedReceiverStream;
 use tracing::{error, info};

+type BoxedOpaqueStream = Box<OpaqueStream>;
+
 unsafe impl Send for LlamaCppBackendImpl {}

 impl From<&ValidParameters> for SamplingParams {
@@ -86,7 +90,7 @@ impl LlamaCppBackend {
         );

         let (submitter, receiver) = channel();
-        let handle = spawn(|| scheduler_loop(backend, receiver));
+        let handle = unsafe { spawn(|| scheduler_loop(backend, receiver)) };
         Ok(Self {
             backlog: submitter,
             scheduler_handle: handle,
@@ -94,47 +98,59 @@
     }
 }

-fn scheduler_loop(
+fn llama_generate_callback(
+    channel: *mut OpaqueStream,
+    new_token_id: u32,
+    new_token_logit: f32,
+    is_eos: bool,
+) {
+    let response = InferStreamResponse::Intermediate {
+        token: Token {
+            id: new_token_id,
+            text: "".to_string(),
+            logprob: new_token_logit,
+            special: false,
+        },
+        top_tokens: vec![],
+    };
+    println!("Generated token: {new_token_id} -> logits={new_token_logit}, is_eos={is_eos}");
+
+    unsafe {
+        if let Err(ref err) = (*channel).0.send(Ok(response)) {
+            error!(
+                "Failed to send back token to the client: {}",
+                err.to_string()
+            );
+        }
+    }
+}
+
+unsafe fn scheduler_loop(
     mut backend: UniquePtr<LlamaCppBackendImpl>,
     mut backlog: Receiver<InferContext>,
 ) {
     loop {
-        println!("Looping");
         if let Ok(mut ctx) = backlog.recv() {
-            println!("{ctx:?}, {}", &ctx.generated_tokens.capacity());
-            match backend.pin_mut().generate(
+            let stream = BoxedOpaqueStream::new(OpaqueStream(ctx.stream));
+            let stream_ptr = Box::into_raw(stream);
+
+            let result = backend.pin_mut().stream(
                 &ctx.input_tokens,
                 &mut ctx.generated_tokens,
-                &ctx.generation_params,
+                ctx.generation_params,
                 &ctx.sampling_params,
-                |new_token_id: u32, new_token_logit: f32, is_eos: bool| {
-                    let response = InferStreamResponse::Intermediate {
-                        token: Token {
-                            id: new_token_id,
-                            text: "".to_string(),
-                            logprob: new_token_logit,
-                            special: false,
-                        },
-                        top_tokens: vec![],
-                    };
-                    println!("Generated token: {response:?}");
-                    // let _ = tokio::spawn(async {
-                    //     match ctx.stream.send(Ok(response)) {
-                    //         Ok(_) => {}
-                    //         Err(ref err) => {
-                    //             error!(
-                    //                 "Failed to send back token to the client: {}",
-                    //                 err.to_string()
-                    //             );
-                    //         }
-                    //     }
-                    // });
-                },
-            ) {
+                stream_ptr,
+                llama_generate_callback,
+            );
+
+            // Make sure we re-keep track of the OpaqueStream box
+            let _ = Box::from_raw(stream_ptr);
+
+            match result {
                 Ok(n_tokens) => {
                     unsafe {
                         ctx.generated_tokens.set_len(n_tokens);
                     }
                     println!(
                         "Generated {} tokens -> {:?}",
                         n_tokens, &ctx.generated_tokens
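The `let _ = Box::from_raw(stream_ptr);` after the call is the fix the commit title names: ownership of the boxed stream is reclaimed once, after the whole generation, while `llama_generate_callback` only reads through the pointer. Had the callback rebuilt the box itself, the sender would be dropped as soon as the first invocation returned and every later token would fail to send. A sketch of the wrong and the right callback body (hypothetical `u32` payload in place of `InferStreamResponse`):

use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};

struct OpaqueStream(UnboundedSender<u32>);

// WRONG: rebuilding the box inside the callback takes ownership, so the
// sender is dropped as soon as the first invocation returns and the
// channel closes after a single token.
#[allow(dead_code)]
unsafe fn on_token_wrong(channel: *mut OpaqueStream, token: u32) {
    let stream = unsafe { Box::from_raw(channel) };
    let _ = stream.0.send(token);
} // `stream` (and the sender) dropped here

// RIGHT: borrow through the pointer and let the caller reclaim the box
// with Box::from_raw once the whole FFI call has returned.
unsafe fn on_token_right(channel: *mut OpaqueStream, token: u32) {
    let _ = unsafe { (*channel).0.send(token) };
}

fn main() {
    let (tx, mut rx) = unbounded_channel();
    let ptr = Box::into_raw(Box::new(OpaqueStream(tx)));
    unsafe {
        on_token_right(ptr, 1);
        on_token_right(ptr, 2);
        drop(Box::from_raw(ptr)); // single owner, single drop
    }
    assert_eq!(rx.blocking_recv(), Some(1));
    assert_eq!(rx.blocking_recv(), Some(2));
}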

backends/llamacpp/src/lib.rs

@@ -1,4 +1,6 @@
 use crate::ffi::SamplingParams;
+use text_generation_router::infer::{InferError, InferStreamResponse};
+use tokio::sync::mpsc::UnboundedSender;

 pub mod backend;
@@ -14,6 +16,8 @@ impl Default for SamplingParams {
     }
 }

+struct OpaqueStream(UnboundedSender<Result<InferStreamResponse, InferError>>);
+
 #[cxx::bridge(namespace = "huggingface::tgi::backends::llamacpp")]
 mod ffi {
     #[derive(Debug, Copy, Clone)]
@@ -31,6 +35,10 @@ mod ffi {
         seed: u64,
     }

+    extern "Rust" {
+        type OpaqueStream;
+    }
+
     unsafe extern "C++" {
         include!("backends/llamacpp/csrc/ffi.hpp");
@@ -47,13 +55,22 @@ mod ffi {
         #[rust_name = "create_single_worker_backend"]
         fn create_single_worker_backend(modelPath: &str) -> Result<UniquePtr<LlamaCppBackendImpl>>;

-        fn generate(
+        // fn generate(
+        //     self: Pin<&mut LlamaCppBackendImpl>,
+        //     tokens: &[u32],
+        //     generated: &mut [u32],
+        //     generation_params: GenerationParams,
+        //     sampling_params: &SamplingParams,
+        // ) -> Result<usize>;
+
+        unsafe fn stream(
             self: Pin<&mut LlamaCppBackendImpl>,
             tokens: &[u32],
             generated: &mut [u32],
-            generation_params: &GenerationParams,
+            generation_params: GenerationParams,
             sampling_params: &SamplingParams,
-            callback: fn(u32, f32, bool),
+            stream: *mut OpaqueStream,
+            callback: unsafe fn(*mut OpaqueStream, u32, f32, bool),
         ) -> Result<usize>;
     }
 }
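A note on the bridge shape: a `fn`-typed argument crosses cxx as a bare function pointer (`rust::Fn<...>` on the C++ side, as in ffi.hpp above), so it cannot capture Rust state. That is why the sender travels as an explicit `*mut OpaqueStream` and why the type is declared opaque in `extern "Rust"`: C++ only ever sees the pointer, never the layout. The same constraint reproduced in plain Rust, as a sketch:

struct OpaqueStream(Vec<u32>);

// A bare function: no environment, so any state must arrive through the
// context argument. This is the only callable shape a fn pointer accepts.
unsafe fn callback(ctx: *mut OpaqueStream, token: u32, _logit: f32, _is_eos: bool) {
    unsafe { (*ctx).0.push(token) };
}

fn takes_fn_pointer(cb: unsafe fn(*mut OpaqueStream, u32, f32, bool)) {
    let _ = cb; // just a pointer-sized value, like C++'s rust::Fn
}

fn main() {
    // OK: a plain fn item coerces to a function pointer.
    takes_fn_pointer(callback);

    // Does NOT compile: a capturing closure has no fn-pointer type, which
    // is why ctx.stream could not simply be captured as before this commit.
    // let sender: Vec<u32> = Vec::new();
    // takes_fn_pointer(move |_ctx, t, _l, _e| { let _ = (&sender, t); });
}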