diff --git a/backends/llamacpp/src/backend.rs b/backends/llamacpp/src/backend.rs
index 2dd5b70d..dc29b707 100644
--- a/backends/llamacpp/src/backend.rs
+++ b/backends/llamacpp/src/backend.rs
@@ -53,10 +53,10 @@ pub(crate) struct GenerationContext {
     pub(crate) sampling_params: SamplingParams,
 }
 
-pub(crate) struct InferContext {
+pub(crate) struct InferContext<'a> {
     pub(crate) start: Instant,
     pub(crate) stream: UnboundedSender<InferResult>,
-    pub(crate) tokenizer: Tokenizer,
+    pub(crate) tokenizer: &'a Tokenizer,
     pub(crate) generation: GenerationContext,
 }
 
@@ -69,11 +69,6 @@ pub enum LlamaCppBackendError {
     ModelInitializationFailed(PathBuf, String),
 }
-// pub struct LlamaCppBackend {
-//     backlog: Sender<(GenerationContext, UnboundedSender<InferResult>)>,
-//     _scheduler_handle: JoinHandle<()>,
-// }
-
 struct LlamaCppWorker {
     sender: Sender<(GenerationContext, UnboundedSender<InferResult>)>,
     handle: JoinHandle<()>,
 }
@@ -95,7 +90,7 @@ impl LlamaCppBackend {
 
     pub fn new<P: AsRef<Path> + Send>(
         model_path: P,
-        tokenizer: Tokenizer,
+        tokenizer: Arc<Tokenizer>,
         num_cores_per_instance: u16,
     ) -> Result<Self, LlamaCppBackendError> {
         let shared_path = Arc::new(model_path);
@@ -110,7 +105,7 @@ impl LlamaCppBackend {
             0 => {
                 let worker = Self::allocate_worker(path)?;
                 let (sender, receiver) = channel();
-                let handle = spawn(|| scheduler_loop(worker, tokenizer, receiver));
+                let handle = spawn(move || scheduler_loop(worker, tokenizer, receiver));
                 LlamaCppBackend::Single(LlamaCppWorker { sender, handle })
             }
             _ => panic!("No supported yet"),
@@ -186,7 +181,7 @@ fn llama_generate_callback(
 
 fn scheduler_loop(
     mut backend: UniquePtr<LlamaCppWorkerFrontend>,
-    tokenizer: Tokenizer,
+    tokenizer: Arc<Tokenizer>,
     backlog: Receiver<(GenerationContext, UnboundedSender<InferResult>)>,
 ) {
     // This loop will mostly decode single token at every step, so no need to rely on parallelism
@@ -195,37 +190,34 @@ fn scheduler_loop(
     loop {
         if let Ok((generation, stream)) = backlog.recv() {
             let start = Instant::now();
-            let tokenizer = tokenizer.clone();
             let generation_params = generation.generation_params; // copy
             let sampling_params = generation.sampling_params; // copy
             let input_tokens = Arc::clone(&generation.input_tokens);
 
             // Creating the whole InferContext and pushing it to the heap
-            {
-                let ctx = Box::new(InferContext {
-                    start,
-                    stream,
-                    tokenizer,
-                    generation,
-                });
+            let ctx = Box::new(InferContext {
+                start,
+                stream,
+                tokenizer: &tokenizer,
+                generation,
+            });
 
-                // We leak the box to avoid it being freed after the first callback call
-                // when going out of scope
-                unsafe {
-                    let boxed_ctx = Box::into_raw(ctx);
-                    if let Err(e) = backend.pin_mut().stream(
-                        &input_tokens,
-                        generation_params,
-                        &sampling_params,
-                        boxed_ctx,
-                        llama_generate_callback,
-                    ) {
-                        error!("Error while decoding tokens... {}", e.what());
-                    }
-
-                    // Make sure we re-keep track of the OpaqueStream box
-                    let _ = Box::from_raw(boxed_ctx);
+            // We leak the box to avoid it being freed after the first callback call
+            // when going out of scope
+            unsafe {
+                let boxed_ctx = Box::into_raw(ctx);
+                if let Err(e) = backend.pin_mut().stream(
+                    &input_tokens,
+                    generation_params,
+                    &sampling_params,
+                    boxed_ctx,
+                    llama_generate_callback,
+                ) {
+                    error!("Error while decoding tokens... {}", e.what());
                 }
+
+                // Make sure we re-keep track of the OpaqueStream box
+                let _ = Box::from_raw(boxed_ctx);
             }
         } else {
             info!("IPC channel is closed, exiting the scheduler loop");
diff --git a/backends/llamacpp/src/lib.rs b/backends/llamacpp/src/lib.rs
index 4f0fa800..8fc98955 100644
--- a/backends/llamacpp/src/lib.rs
+++ b/backends/llamacpp/src/lib.rs
@@ -33,7 +33,7 @@ mod ffi {
     }
 
     extern "Rust" {
-        type InferContext;
+        type InferContext<'a>;
     }
 
     unsafe extern "C++" {
diff --git a/backends/llamacpp/src/main.rs b/backends/llamacpp/src/main.rs
index a2abd555..adc183ed 100644
--- a/backends/llamacpp/src/main.rs
+++ b/backends/llamacpp/src/main.rs
@@ -1,5 +1,6 @@
 use clap::{Parser, Subcommand};
 use std::path::PathBuf;
+use std::sync::Arc;
 use text_generation_backend_llamacpp::backend::{LlamaCppBackend, LlamaCppBackendError};
 use text_generation_router::server::ApiDoc;
 use text_generation_router::{server, usage_stats};
@@ -162,8 +163,10 @@ async fn main() -> Result<(), RouterError> {
         user_agent: Default::default(),
         auth_token,
     };
-    let tokenizer = tokenizers::Tokenizer::from_pretrained(tokenizer_name.clone(), Some(options))
-        .expect("Failed to retrieve tokenizer");
+    let tokenizer = Arc::new(
+        tokenizers::Tokenizer::from_pretrained(tokenizer_name.clone(), Some(options))
+            .expect("Failed to retrieve tokenizer"),
+    );
     let backend = LlamaCppBackend::new(gguf_path, tokenizer, num_cores_per_instance.unwrap_or(0))?;
 
     // Run server