feat(backend): wrap tokenizer in an Arc to avoid duplicating it

Morgan Funtowicz 2024-11-14 08:41:38 +01:00
parent 57b215467b
commit 6f059c4b5d
3 changed files with 32 additions and 37 deletions
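
The change shares a single tokenizer behind an Arc and hands out cheap references, instead of cloning the tokenizer for every request. Below is a minimal, self-contained sketch of that pattern; it uses a stand-in Tokenizer struct rather than the real tokenizers::Tokenizer, and the names are illustrative only.

use std::sync::Arc;
use std::thread;

// Stand-in for the real tokenizer: cloning this would duplicate all of its state.
struct Tokenizer {
    vocab: Vec<String>,
}

fn main() {
    let tokenizer = Arc::new(Tokenizer {
        vocab: vec!["hello".to_string(), "world".to_string()],
    });

    // Cloning the Arc only bumps a reference count; the tokenizer itself
    // is built once and shared with the worker thread.
    let shared = Arc::clone(&tokenizer);
    let handle = thread::spawn(move || shared.vocab.len());

    assert_eq!(handle.join().unwrap(), tokenizer.vocab.len());
}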

View File

@@ -53,10 +53,10 @@ pub(crate) struct GenerationContext {
pub(crate) sampling_params: SamplingParams,
}
- pub(crate) struct InferContext {
+ pub(crate) struct InferContext<'a> {
pub(crate) start: Instant,
pub(crate) stream: UnboundedSender<InferResult>,
- pub(crate) tokenizer: Tokenizer,
+ pub(crate) tokenizer: &'a Tokenizer,
pub(crate) generation: GenerationContext,
}
@@ -69,11 +69,6 @@ pub enum LlamaCppBackendError {
ModelInitializationFailed(PathBuf, String),
}
- // pub struct LlamaCppBackend {
- // backlog: Sender<(GenerationContext, UnboundedSender<InferResult>)>,
- // _scheduler_handle: JoinHandle<()>,
- // }
struct LlamaCppWorker {
sender: Sender<(GenerationContext, UnboundedSender<InferResult>)>,
handle: JoinHandle<()>,
@@ -95,7 +90,7 @@ impl LlamaCppBackend {
pub fn new<P: AsRef<Path>>(
model_path: P,
- tokenizer: Tokenizer,
+ tokenizer: Arc<Tokenizer>,
num_cores_per_instance: u16,
) -> Result<Self, LlamaCppBackendError> {
let shared_path = Arc::new(model_path);
@@ -110,7 +105,7 @@ impl LlamaCppBackend {
0 => {
let worker = Self::allocate_worker(path)?;
let (sender, receiver) = channel();
- let handle = spawn(|| scheduler_loop(worker, tokenizer, receiver));
+ let handle = spawn(move || scheduler_loop(worker, tokenizer, receiver));
LlamaCppBackend::Single(LlamaCppWorker { sender, handle })
}
_ => panic!("Not supported yet"),
@@ -186,7 +181,7 @@ fn llama_generate_callback(
fn scheduler_loop(
mut backend: UniquePtr<LlamaCppWorkerFrontend>,
- tokenizer: Tokenizer,
+ tokenizer: Arc<Tokenizer>,
backlog: Receiver<(GenerationContext, UnboundedSender<InferResult>)>,
) {
// This loop will mostly decode single token at every step, so no need to rely on parallelism
@@ -195,37 +190,34 @@ fn scheduler_loop(
loop {
if let Ok((generation, stream)) = backlog.recv() {
let start = Instant::now();
- let tokenizer = tokenizer.clone();
let generation_params = generation.generation_params; // copy
let sampling_params = generation.sampling_params; // copy
let input_tokens = Arc::clone(&generation.input_tokens);
// Creating the whole InferContext and pushing it to the heap
- {
- let ctx = Box::new(InferContext {
- start,
- stream,
- tokenizer,
- generation,
- });
+ let ctx = Box::new(InferContext {
+ start,
+ stream,
+ tokenizer: &tokenizer,
+ generation,
+ });
- // We leak the box to avoid it being freed after the first callback call
- // when going out of scope
- unsafe {
- let boxed_ctx = Box::into_raw(ctx);
- if let Err(e) = backend.pin_mut().stream(
- &input_tokens,
- generation_params,
- &sampling_params,
- boxed_ctx,
- llama_generate_callback,
- ) {
- error!("Error while decoding tokens... {}", e.what());
- }
- // Make sure we re-keep track of the OpaqueStream box
- let _ = Box::from_raw(boxed_ctx);
- }
- }
+ // We leak the box to avoid it being freed after the first callback call
+ // when going out of scope
+ unsafe {
+ let boxed_ctx = Box::into_raw(ctx);
+ if let Err(e) = backend.pin_mut().stream(
+ &input_tokens,
+ generation_params,
+ &sampling_params,
+ boxed_ctx,
+ llama_generate_callback,
+ ) {
+ error!("Error while decoding tokens... {}", e.what());
+ }
+ // Make sure we re-keep track of the OpaqueStream box
+ let _ = Box::from_raw(boxed_ctx);
+ }
} else {
info!("IPC channel is closed, exiting the scheduler loop");
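
The Box::into_raw / Box::from_raw dance in the hunk above is the general pattern for handing an owned Rust context to an FFI call that drives a callback: leak the box so the raw pointer stays valid for the whole call, then reclaim ownership afterwards so the context is dropped exactly once. Here is a minimal sketch of the same idea; Ctx, fake_ffi_stream, and on_token are illustrative stand-ins, not names from this codebase.

struct Ctx {
    tokens_seen: usize,
}

// Stand-in for the C++ side: it invokes the callback once per "token",
// handing back the opaque context pointer it was given.
fn fake_ffi_stream(ctx: *mut Ctx, callback: fn(*mut Ctx, u32)) {
    for token in [1_u32, 2, 3] {
        callback(ctx, token);
    }
}

fn on_token(ctx: *mut Ctx, _token: u32) {
    // Safety: the pointer comes from Box::into_raw below and is only
    // reclaimed after fake_ffi_stream returns.
    unsafe { (*ctx).tokens_seen += 1 };
}

fn main() {
    let ctx = Box::new(Ctx { tokens_seen: 0 });
    // Leak the box so the pointer stays valid for the whole call...
    let raw = Box::into_raw(ctx);
    fake_ffi_stream(raw, on_token);
    // ...then take ownership back so the context is dropped exactly once.
    let ctx = unsafe { Box::from_raw(raw) };
    assert_eq!(ctx.tokens_seen, 3);
}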

View File

@@ -33,7 +33,7 @@ mod ffi {
}
extern "Rust" {
- type InferContext;
+ type InferContext<'a>;
}
unsafe extern "C++" {

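The bridge change above keeps the opaque type declaration in the cxx bridge in sync with the new lifetime on the Rust struct. A minimal sketch of declaring a borrowing opaque Rust type with the cxx crate follows; the Borrowed type and its text field are illustrative only.

// Assumes the cxx crate as a dependency.
#[cxx::bridge]
mod ffi {
    extern "Rust" {
        // Opaque to C++, but allowed to borrow data on the Rust side.
        type Borrowed<'a>;
    }
}

// The Rust definition must carry the same lifetime parameter as the
// declaration inside the bridge.
pub struct Borrowed<'a> {
    pub text: &'a str,
}
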
View File

@@ -1,5 +1,6 @@
use clap::{Parser, Subcommand};
use std::path::PathBuf;
+ use std::sync::Arc;
use text_generation_backend_llamacpp::backend::{LlamaCppBackend, LlamaCppBackendError};
use text_generation_router::server::ApiDoc;
use text_generation_router::{server, usage_stats};
@@ -162,8 +163,10 @@ async fn main() -> Result<(), RouterError> {
user_agent: Default::default(),
auth_token,
};
- let tokenizer = tokenizers::Tokenizer::from_pretrained(tokenizer_name.clone(), Some(options))
- .expect("Failed to retrieve tokenizer");
+ let tokenizer = Arc::new(
+ tokenizers::Tokenizer::from_pretrained(tokenizer_name.clone(), Some(options))
+ .expect("Failed to retrieve tokenizer"),
+ );
let backend = LlamaCppBackend::new(gguf_path, tokenizer, num_cores_per_instance.unwrap_or(0))?;
// Run server