diff --git a/backends/llamacpp/csrc/ffi.hpp b/backends/llamacpp/csrc/ffi.hpp
index 5c404b01..c823b72b 100644
--- a/backends/llamacpp/csrc/ffi.hpp
+++ b/backends/llamacpp/csrc/ffi.hpp
@@ -21,6 +21,7 @@ namespace huggingface::tgi::backends::llamacpp {

 #include "backends/llamacpp/src/lib.rs.h"
+#include "rust/cxx.h"

 namespace huggingface::tgi::backends::llamacpp {

@@ -61,17 +62,22 @@ namespace huggingface::tgi::backends::llamacpp {
         explicit llama_cpp_backend_impl_t(multi_worker_backend_t &&backend) : mInner_(std::move(backend)) {}

-        size_t generate(
+        size_t stream(
                 rust::Slice<const uint32_t> input_tokens,
                 rust::Slice<uint32_t> generated_tokens,
-                const generation_params_t &generation_params,
+                const generation_params_t generation_params,
                 const sampling_params_t &sampling_params,
-                rust::Fn<void(uint32_t, float_t, bool)> callback
+                OpaqueStream *stream,
+                rust::Fn<void(OpaqueStream *, uint32_t, float_t, bool)> callback
         ) {
             // Define the visitor lambda function which requires the has_emplace_generate constraint on T
-            static auto inner_fw = [=, &generation_params, &sampling_params](T &&backend)
+            static auto inner_fw = [=, &sampling_params, &stream, &callback](T &&backend)
                     -> std::expected<size_t, backend_error_t> {
+                auto context_forwarding_callback = [=, &stream](uint32_t new_token_id, float_t logits, bool is_eos){
+                    callback(stream, new_token_id, logits, is_eos);
+                };
+
                 // Ask the compiler to create view over Rust slice transmuting from uint32_t* to int32_t*
                 auto input_tokens_v =
                         std::span(reinterpret_cast<const llama_token *>(input_tokens.data()), input_tokens.size());
@@ -79,7 +85,12 @@ namespace huggingface::tgi::backends::llamacpp {
                         std::span(reinterpret_cast<llama_token *>(generated_tokens.data()), generated_tokens.size());

                 return backend.generate(
-                        input_tokens_v, generated_tokens_v, generation_params, sampling_params, callback);
+                        input_tokens_v,
+                        generated_tokens_v,
+                        generation_params,
+                        sampling_params,
+                        context_forwarding_callback
+                );
             };

             if (const auto result = std::visit(inner_fw, mInner_); result.has_value()) {
diff --git a/backends/llamacpp/src/backend.rs b/backends/llamacpp/src/backend.rs
index 670f4397..09afbc7b 100644
--- a/backends/llamacpp/src/backend.rs
+++ b/backends/llamacpp/src/backend.rs
@@ -1,23 +1,27 @@
 use crate::ffi::{
     create_single_worker_backend, GenerationParams, LlamaCppBackendImpl, SamplingParams,
 };
+use crate::OpaqueStream;
 use async_trait::async_trait;
 use cxx::UniquePtr;
 use std::path::{Path, PathBuf};
 use std::sync::mpsc::{channel, Receiver, SendError, Sender};
 use std::sync::Arc;
 use std::thread::{spawn, JoinHandle};
-use text_generation_router::infer::{Backend, InferError, InferStreamResponse};
+use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
 use text_generation_router::validation::{
     ValidGenerateRequest, ValidParameters, ValidStoppingParameters,
 };
-use text_generation_router::Token;
+use text_generation_router::{FinishReason, Token};
 use thiserror::Error;
 use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
 use tokio::sync::TryAcquireError;
+use tokio::time::Instant;
 use tokio_stream::wrappers::UnboundedReceiverStream;
 use tracing::{error, info};

+type BoxedOpaqueStream = Box<OpaqueStream>;
+
 unsafe impl Send for LlamaCppBackendImpl {}

 impl From<&ValidParameters> for SamplingParams {
@@ -86,7 +90,7 @@ impl LlamaCppBackend {
         );

         let (submitter, receiver) = channel();
-        let handle = spawn(|| scheduler_loop(backend, receiver));
+        let handle = unsafe { spawn(|| scheduler_loop(backend, receiver)) };
         Ok(Self {
             backlog: submitter,
             scheduler_handle: handle,
@@ -94,47 +98,59 @@
     }
 }

-fn scheduler_loop(
+fn llama_generate_callback(
+    channel: *mut OpaqueStream,
+    new_token_id: u32,
+    new_token_logit: f32,
+    is_eos: bool,
+) {
+    let response = InferStreamResponse::Intermediate {
+        token: Token {
+            id: new_token_id,
+            text: "".to_string(),
+            logprob: new_token_logit,
+            special: false,
+        },
+        top_tokens: vec![],
+    };
+    println!("Generated token: {new_token_id} -> logits={new_token_logit}, is_eos={is_eos}");
+
+    unsafe {
+        if let Err(ref err) = (*channel).0.send(Ok(response)) {
+            error!(
+                "Failed to send back token to the client: {}",
+                err.to_string()
+            );
+        }
+    }
+}
+
+unsafe fn scheduler_loop(
     mut backend: UniquePtr<LlamaCppBackendImpl>,
     mut backlog: Receiver<InferContext>,
 ) {
     loop {
-        println!("Looping");
         if let Ok(mut ctx) = backlog.recv() {
-            println!("{ctx:?}, {}", &ctx.generated_tokens.capacity());
-            match backend.pin_mut().generate(
+            let stream = BoxedOpaqueStream::new(OpaqueStream(ctx.stream));
+            let stream_ptr = Box::into_raw(stream);
+            let result = backend.pin_mut().stream(
                 &ctx.input_tokens,
                 &mut ctx.generated_tokens,
-                &ctx.generation_params,
+                ctx.generation_params,
                 &ctx.sampling_params,
-                |new_token_id: u32, new_token_logit: f32, is_eos: bool| {
-                    let response = InferStreamResponse::Intermediate {
-                        token: Token {
-                            id: new_token_id,
-                            text: "".to_string(),
-                            logprob: new_token_logit,
-                            special: false,
-                        },
-                        top_tokens: vec![],
-                    };
-                    println!("Generated token: {response:?}");
-                    // let _ = tokio::spawn(async {
-                    //     match ctx.stream.send(Ok(response)) {
-                    //         Ok(_) => {}
-                    //         Err(ref err) => {
-                    //             error!(
-                    //                 "Failed to send back token to the client: {}",
-                    //                 err.to_string()
-                    //             );
-                    //         }
-                    //     }
-                    // });
-                },
-            ) {
+                stream_ptr,
+                llama_generate_callback,
+            );
+
+            // Make sure we re-keep track of the OpaqueStream box
+            let _ = Box::from_raw(stream_ptr);
+
+            match result {
                 Ok(n_tokens) => {
                     unsafe {
                         ctx.generated_tokens.set_len(n_tokens);
                     }
+
                     println!(
                         "Generated {} tokens -> {:?}",
                         n_tokens, &ctx.generated_tokens
diff --git a/backends/llamacpp/src/lib.rs b/backends/llamacpp/src/lib.rs
index 489188c1..f923526f 100644
--- a/backends/llamacpp/src/lib.rs
+++ b/backends/llamacpp/src/lib.rs
@@ -1,4 +1,6 @@
 use crate::ffi::SamplingParams;
+use text_generation_router::infer::{InferError, InferStreamResponse};
+use tokio::sync::mpsc::UnboundedSender;

 pub mod backend;

@@ -14,6 +16,8 @@ impl Default for SamplingParams {
     }
 }

+struct OpaqueStream(UnboundedSender<Result<InferStreamResponse, InferError>>);
+
 #[cxx::bridge(namespace = "huggingface::tgi::backends::llamacpp")]
 mod ffi {
     #[derive(Debug, Copy, Clone)]
@@ -31,6 +35,10 @@ mod ffi {
         seed: u64,
     }

+    extern "Rust" {
+        type OpaqueStream;
+    }
+
     unsafe extern "C++" {
         include!("backends/llamacpp/csrc/ffi.hpp");

@@ -47,13 +55,22 @@ mod ffi {
         #[rust_name = "create_single_worker_backend"]
         fn create_single_worker_backend(modelPath: &str) -> Result<UniquePtr<LlamaCppBackendImpl>>;

-        fn generate(
+        // fn generate(
+        //     self: Pin<&mut LlamaCppBackendImpl>,
+        //     tokens: &[u32],
+        //     generated: &mut [u32],
+        //     generation_params: GenerationParams,
+        //     sampling_params: &SamplingParams,
+        // ) -> Result<usize>;
+
+        unsafe fn stream(
             self: Pin<&mut LlamaCppBackendImpl>,
             tokens: &[u32],
             generated: &mut [u32],
-            generation_params: &GenerationParams,
+            generation_params: GenerationParams,
             sampling_params: &SamplingParams,
-            callback: fn(u32, f32, bool),
+            stream: *mut OpaqueStream,
+            callback: unsafe fn(*mut OpaqueStream, u32, f32, bool),
         ) -> Result<usize>;
     }
 }
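The ownership hand-off this patch introduces is easy to get wrong, so here is a minimal, self-contained sketch of the same pattern in plain Rust (no cxx, no llama.cpp): box the per-request sender, leak it to a raw pointer with `Box::into_raw`, pass the pointer plus a plain `fn` callback to the "backend", and reclaim the box with `Box::from_raw` once the call returns. `OpaqueStream`, `forward_token`, and `fake_backend_stream` below are illustrative stand-ins, and `std::sync::mpsc` stands in for the tokio unbounded channel used by the real backend.

```rust
use std::sync::mpsc::{channel, Sender};

// Stand-in for the real OpaqueStream newtype, which wraps an UnboundedSender of infer results.
struct OpaqueStream(Sender<u32>);

// Same shape as llama_generate_callback: an unsafe fn that receives the opaque pointer back
// from the backend along with per-token data, and forwards the token over the channel.
unsafe fn forward_token(stream: *mut OpaqueStream, new_token_id: u32) {
    // SAFETY: `stream` is the pointer produced by Box::into_raw below and outlives this call.
    unsafe {
        if let Err(err) = (*stream).0.send(new_token_id) {
            eprintln!("failed to forward token: {err}");
        }
    }
}

// Stand-in for the C++ `stream` entry point: it only ever sees an opaque pointer and a callback.
unsafe fn fake_backend_stream(
    stream: *mut OpaqueStream,
    callback: unsafe fn(*mut OpaqueStream, u32),
) {
    for token_id in [1u32, 2, 3] {
        unsafe { callback(stream, token_id) };
    }
}

fn main() {
    let (tx, rx) = channel();

    // Box the sender and leak it to a raw pointer for the duration of the call,
    // mirroring BoxedOpaqueStream::new(..) + Box::into_raw(..) in scheduler_loop.
    let stream_ptr = Box::into_raw(Box::new(OpaqueStream(tx)));

    unsafe {
        fake_backend_stream(stream_ptr, forward_token);
        // Reclaim ownership so the sender is dropped and the channel closes,
        // mirroring the Box::from_raw(stream_ptr) after `stream` returns.
        drop(Box::from_raw(stream_ptr));
    }

    // The receiver drains every token the "backend" emitted through the callback.
    for token_id in rx {
        println!("received token {token_id}");
    }
}
```

Note that reclaiming the box immediately after the call is only sound because the FFI call runs synchronously on the scheduler thread; if the C++ side ever held on to the pointer past the call, the callback would dereference freed memory.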