feat(backend): wrap Arc tokenizer to avoid duplicating

This commit is contained in:
Morgan Funtowicz 2024-11-14 08:41:38 +01:00
parent 57b215467b
commit 6f059c4b5d
3 changed files with 32 additions and 37 deletions

View File

@ -53,10 +53,10 @@ pub(crate) struct GenerationContext {
pub(crate) sampling_params: SamplingParams, pub(crate) sampling_params: SamplingParams,
} }
pub(crate) struct InferContext { pub(crate) struct InferContext<'a> {
pub(crate) start: Instant, pub(crate) start: Instant,
pub(crate) stream: UnboundedSender<InferResult>, pub(crate) stream: UnboundedSender<InferResult>,
pub(crate) tokenizer: Tokenizer, pub(crate) tokenizer: &'a Tokenizer,
pub(crate) generation: GenerationContext, pub(crate) generation: GenerationContext,
} }
@ -69,11 +69,6 @@ pub enum LlamaCppBackendError {
ModelInitializationFailed(PathBuf, String), ModelInitializationFailed(PathBuf, String),
} }
// pub struct LlamaCppBackend {
// backlog: Sender<(GenerationContext, UnboundedSender<InferResult>)>,
// _scheduler_handle: JoinHandle<()>,
// }
struct LlamaCppWorker { struct LlamaCppWorker {
sender: Sender<(GenerationContext, UnboundedSender<InferResult>)>, sender: Sender<(GenerationContext, UnboundedSender<InferResult>)>,
handle: JoinHandle<()>, handle: JoinHandle<()>,
@ -95,7 +90,7 @@ impl LlamaCppBackend {
pub fn new<P: AsRef<Path>>( pub fn new<P: AsRef<Path>>(
model_path: P, model_path: P,
tokenizer: Tokenizer, tokenizer: Arc<Tokenizer>,
num_cores_per_instance: u16, num_cores_per_instance: u16,
) -> Result<Self, LlamaCppBackendError> { ) -> Result<Self, LlamaCppBackendError> {
let shared_path = Arc::new(model_path); let shared_path = Arc::new(model_path);
@ -110,7 +105,7 @@ impl LlamaCppBackend {
0 => { 0 => {
let worker = Self::allocate_worker(path)?; let worker = Self::allocate_worker(path)?;
let (sender, receiver) = channel(); let (sender, receiver) = channel();
let handle = spawn(|| scheduler_loop(worker, tokenizer, receiver)); let handle = spawn(move || scheduler_loop(worker, tokenizer, receiver));
LlamaCppBackend::Single(LlamaCppWorker { sender, handle }) LlamaCppBackend::Single(LlamaCppWorker { sender, handle })
} }
_ => panic!("No supported yet"), _ => panic!("No supported yet"),
@ -186,7 +181,7 @@ fn llama_generate_callback(
fn scheduler_loop( fn scheduler_loop(
mut backend: UniquePtr<LlamaCppWorkerFrontend>, mut backend: UniquePtr<LlamaCppWorkerFrontend>,
tokenizer: Tokenizer, tokenizer: Arc<Tokenizer>,
backlog: Receiver<(GenerationContext, UnboundedSender<InferResult>)>, backlog: Receiver<(GenerationContext, UnboundedSender<InferResult>)>,
) { ) {
// This loop will mostly decode single token at every step, so no need to rely on parallelism // This loop will mostly decode single token at every step, so no need to rely on parallelism
@ -195,17 +190,15 @@ fn scheduler_loop(
loop { loop {
if let Ok((generation, stream)) = backlog.recv() { if let Ok((generation, stream)) = backlog.recv() {
let start = Instant::now(); let start = Instant::now();
let tokenizer = tokenizer.clone();
let generation_params = generation.generation_params; // copy let generation_params = generation.generation_params; // copy
let sampling_params = generation.sampling_params; // copy let sampling_params = generation.sampling_params; // copy
let input_tokens = Arc::clone(&generation.input_tokens); let input_tokens = Arc::clone(&generation.input_tokens);
// Creating the whole InferContext and pushing it to the heap // Creating the whole InferContext and pushing it to the heap
{
let ctx = Box::new(InferContext { let ctx = Box::new(InferContext {
start, start,
stream, stream,
tokenizer, tokenizer: &tokenizer,
generation, generation,
}); });
@ -226,7 +219,6 @@ fn scheduler_loop(
// Make sure we re-keep track of the OpaqueStream box // Make sure we re-keep track of the OpaqueStream box
let _ = Box::from_raw(boxed_ctx); let _ = Box::from_raw(boxed_ctx);
} }
}
} else { } else {
info!("IPC channel is closed, exiting the scheduler loop"); info!("IPC channel is closed, exiting the scheduler loop");
break; break;

View File

@ -33,7 +33,7 @@ mod ffi {
} }
extern "Rust" { extern "Rust" {
type InferContext; type InferContext<'a>;
} }
unsafe extern "C++" { unsafe extern "C++" {

View File

@ -1,5 +1,6 @@
use clap::{Parser, Subcommand}; use clap::{Parser, Subcommand};
use std::path::PathBuf; use std::path::PathBuf;
use std::sync::Arc;
use text_generation_backend_llamacpp::backend::{LlamaCppBackend, LlamaCppBackendError}; use text_generation_backend_llamacpp::backend::{LlamaCppBackend, LlamaCppBackendError};
use text_generation_router::server::ApiDoc; use text_generation_router::server::ApiDoc;
use text_generation_router::{server, usage_stats}; use text_generation_router::{server, usage_stats};
@ -162,8 +163,10 @@ async fn main() -> Result<(), RouterError> {
user_agent: Default::default(), user_agent: Default::default(),
auth_token, auth_token,
}; };
let tokenizer = tokenizers::Tokenizer::from_pretrained(tokenizer_name.clone(), Some(options)) let tokenizer = Arc::new(
.expect("Failed to retrieve tokenizer"); tokenizers::Tokenizer::from_pretrained(tokenizer_name.clone(), Some(options))
.expect("Failed to retrieve tokenizer"),
);
let backend = LlamaCppBackend::new(gguf_path, tokenizer, num_cores_per_instance.unwrap_or(0))?; let backend = LlamaCppBackend::new(gguf_path, tokenizer, num_cores_per_instance.unwrap_or(0))?;
// Run server // Run server