feat(backend): wrap the tokenizer in an Arc to avoid duplicating it

Author: Morgan Funtowicz
Date:   2024-11-14 08:41:38 +01:00
Parent: 57b215467b
Commit: 6f059c4b5d
3 changed files with 32 additions and 37 deletions
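
The change follows the standard Arc-sharing pattern: build the tokenizer once, wrap it in an Arc, and move a cheap reference-counted clone into each worker thread instead of cloning the Tokenizer itself per instance or per request. Below is a minimal, self-contained sketch of that pattern; the Tokenizer here is a stand-in type (not the tokenizers crate) so the example compiles on its own.

    use std::sync::Arc;
    use std::thread;

    // Stand-in for `tokenizers::Tokenizer`; cloning the real type duplicates
    // its vocabulary and merge tables, which is what the commit avoids.
    struct Tokenizer {
        vocab_size: u32,
    }

    impl Tokenizer {
        fn encode(&self, text: &str) -> Vec<u32> {
            // Dummy encoding: one id per whitespace-separated word.
            text.split_whitespace()
                .map(|w| w.len() as u32 % self.vocab_size)
                .collect()
        }
    }

    fn main() {
        // Built once; afterwards only the Arc's reference count is cloned.
        let tokenizer = Arc::new(Tokenizer { vocab_size: 32_000 });

        let handles: Vec<_> = (0..4)
            .map(|worker_id| {
                let tokenizer = Arc::clone(&tokenizer);
                // `move` transfers the Arc into the thread, mirroring
                // `spawn(move || scheduler_loop(worker, tokenizer, receiver))`.
                thread::spawn(move || {
                    let ids = tokenizer.encode("hello world");
                    println!("worker {worker_id} encoded {} tokens", ids.len());
                })
            })
            .collect();

        for handle in handles {
            handle.join().unwrap();
        }
    }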

Changed file 1 of 3:

@@ -53,10 +53,10 @@ pub(crate) struct GenerationContext {
     pub(crate) sampling_params: SamplingParams,
 }
 
-pub(crate) struct InferContext {
+pub(crate) struct InferContext<'a> {
     pub(crate) start: Instant,
     pub(crate) stream: UnboundedSender<InferResult>,
-    pub(crate) tokenizer: Tokenizer,
+    pub(crate) tokenizer: &'a Tokenizer,
     pub(crate) generation: GenerationContext,
 }
@@ -69,11 +69,6 @@ pub enum LlamaCppBackendError {
     ModelInitializationFailed(PathBuf, String),
 }
 
-// pub struct LlamaCppBackend {
-//     backlog: Sender<(GenerationContext, UnboundedSender<InferResult>)>,
-//     _scheduler_handle: JoinHandle<()>,
-// }
-
 struct LlamaCppWorker {
     sender: Sender<(GenerationContext, UnboundedSender<InferResult>)>,
     handle: JoinHandle<()>,
@@ -95,7 +90,7 @@ impl LlamaCppBackend {
     pub fn new<P: AsRef<Path>>(
         model_path: P,
-        tokenizer: Tokenizer,
+        tokenizer: Arc<Tokenizer>,
        num_cores_per_instance: u16,
     ) -> Result<Self, LlamaCppBackendError> {
         let shared_path = Arc::new(model_path);
@@ -110,7 +105,7 @@ impl LlamaCppBackend {
             0 => {
                 let worker = Self::allocate_worker(path)?;
                 let (sender, receiver) = channel();
-                let handle = spawn(|| scheduler_loop(worker, tokenizer, receiver));
+                let handle = spawn(move || scheduler_loop(worker, tokenizer, receiver));
                 LlamaCppBackend::Single(LlamaCppWorker { sender, handle })
             }
             _ => panic!("No supported yet"),
@@ -186,7 +181,7 @@ fn llama_generate_callback(
 fn scheduler_loop(
     mut backend: UniquePtr<LlamaCppWorkerFrontend>,
-    tokenizer: Tokenizer,
+    tokenizer: Arc<Tokenizer>,
     backlog: Receiver<(GenerationContext, UnboundedSender<InferResult>)>,
 ) {
     // This loop will mostly decode single token at every step, so no need to rely on parallelism
@@ -195,37 +190,34 @@ fn scheduler_loop(
     loop {
         if let Ok((generation, stream)) = backlog.recv() {
             let start = Instant::now();
-            let tokenizer = tokenizer.clone();
             let generation_params = generation.generation_params; // copy
             let sampling_params = generation.sampling_params; // copy
             let input_tokens = Arc::clone(&generation.input_tokens);
 
             // Creating the whole InferContext and pushing it to the heap
-            {
-                let ctx = Box::new(InferContext {
-                    start,
-                    stream,
-                    tokenizer,
-                    generation,
-                });
-
-                // We leak the box to avoid it being freed after the first callback call
-                // when going out of scope
-                unsafe {
-                    let boxed_ctx = Box::into_raw(ctx);
-                    if let Err(e) = backend.pin_mut().stream(
-                        &input_tokens,
-                        generation_params,
-                        &sampling_params,
-                        boxed_ctx,
-                        llama_generate_callback,
-                    ) {
-                        error!("Error while decoding tokens... {}", e.what());
-                    }
-
-                    // Make sure we re-keep track of the OpaqueStream box
-                    let _ = Box::from_raw(boxed_ctx);
-                }
-            }
+            let ctx = Box::new(InferContext {
+                start,
+                stream,
+                tokenizer: &tokenizer,
+                generation,
+            });
+
+            // We leak the box to avoid it being freed after the first callback call
+            // when going out of scope
+            unsafe {
+                let boxed_ctx = Box::into_raw(ctx);
+                if let Err(e) = backend.pin_mut().stream(
+                    &input_tokens,
+                    generation_params,
+                    &sampling_params,
+                    boxed_ctx,
+                    llama_generate_callback,
+                ) {
+                    error!("Error while decoding tokens... {}", e.what());
+                }
+
+                // Make sure we re-keep track of the OpaqueStream box
+                let _ = Box::from_raw(boxed_ctx);
+            }
         } else {
             info!("IPC channel is closed, exiting the scheduler loop");
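
The scheduler above hands the boxed InferContext to the C++ side as a raw pointer: Box::into_raw leaks it so it stays alive for every callback invocation, and Box::from_raw reclaims ownership once stream() has returned, so the context is freed exactly once. Here is a minimal sketch of that ownership round-trip, with a hypothetical fake_stream standing in for the LlamaCppWorkerFrontend::stream FFI call:

    // Context handed to a callback-style API through a raw pointer.
    struct InferContext {
        request_id: u64,
    }

    // C-compatible callback: it only borrows the context for the duration
    // of the call and must not free it.
    extern "C" fn on_token(ctx: *mut InferContext, token: u32) {
        // SAFETY: the caller keeps `ctx` valid for the whole callback.
        let ctx = unsafe { &mut *ctx };
        println!("request {} received token {}", ctx.request_id, token);
    }

    // Hypothetical stand-in for the C++ `stream()` entry point: it invokes
    // the callback synchronously for each generated token, then returns.
    fn fake_stream(ctx: *mut InferContext, callback: extern "C" fn(*mut InferContext, u32)) {
        for token in [17, 42, 7] {
            callback(ctx, token);
        }
    }

    fn main() {
        let ctx = Box::new(InferContext { request_id: 1 });

        // Leak the box so it is not dropped while the callee still uses it.
        let raw = Box::into_raw(ctx);
        fake_stream(raw, on_token);

        // Take ownership back so the context is dropped exactly once.
        // SAFETY: `raw` came from Box::into_raw and is not used after this.
        let _ctx = unsafe { Box::from_raw(raw) };
    }

Because InferContext<'a> now borrows the Arc-held tokenizer instead of owning a clone, the context stays small and the tokenizer remains valid for the duration of the synchronous stream() call, after which ownership of the box is recovered.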

Changed file 2 of 3:

@@ -33,7 +33,7 @@ mod ffi {
     }
 
     extern "Rust" {
-        type InferContext;
+        type InferContext<'a>;
     }
 
     unsafe extern "C++" {

Changed file 3 of 3:

@@ -1,5 +1,6 @@
 use clap::{Parser, Subcommand};
 use std::path::PathBuf;
+use std::sync::Arc;
 use text_generation_backend_llamacpp::backend::{LlamaCppBackend, LlamaCppBackendError};
 use text_generation_router::server::ApiDoc;
 use text_generation_router::{server, usage_stats};
@@ -162,8 +163,10 @@ async fn main() -> Result<(), RouterError> {
         user_agent: Default::default(),
         auth_token,
     };
-    let tokenizer = tokenizers::Tokenizer::from_pretrained(tokenizer_name.clone(), Some(options))
-        .expect("Failed to retrieve tokenizer");
+    let tokenizer = Arc::new(
+        tokenizers::Tokenizer::from_pretrained(tokenizer_name.clone(), Some(options))
+            .expect("Failed to retrieve tokenizer"),
+    );
     let backend = LlamaCppBackend::new(gguf_path, tokenizer, num_cores_per_instance.unwrap_or(0))?;
 
     // Run server