mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-07-27 10:20:17 +00:00
feat(backend): wrap Arc tokenizer to avoid duplicating
This commit is contained in:
parent
57b215467b
commit
6f059c4b5d
@ -53,10 +53,10 @@ pub(crate) struct GenerationContext {
|
||||
pub(crate) sampling_params: SamplingParams,
|
||||
}
|
||||
|
||||
pub(crate) struct InferContext {
|
||||
pub(crate) struct InferContext<'a> {
|
||||
pub(crate) start: Instant,
|
||||
pub(crate) stream: UnboundedSender<InferResult>,
|
||||
pub(crate) tokenizer: Tokenizer,
|
||||
pub(crate) tokenizer: &'a Tokenizer,
|
||||
pub(crate) generation: GenerationContext,
|
||||
}
|
||||
|
||||
@ -69,11 +69,6 @@ pub enum LlamaCppBackendError {
|
||||
ModelInitializationFailed(PathBuf, String),
|
||||
}
|
||||
|
||||
// pub struct LlamaCppBackend {
|
||||
// backlog: Sender<(GenerationContext, UnboundedSender<InferResult>)>,
|
||||
// _scheduler_handle: JoinHandle<()>,
|
||||
// }
|
||||
|
||||
struct LlamaCppWorker {
|
||||
sender: Sender<(GenerationContext, UnboundedSender<InferResult>)>,
|
||||
handle: JoinHandle<()>,
|
||||
@ -95,7 +90,7 @@ impl LlamaCppBackend {
|
||||
|
||||
pub fn new<P: AsRef<Path>>(
|
||||
model_path: P,
|
||||
tokenizer: Tokenizer,
|
||||
tokenizer: Arc<Tokenizer>,
|
||||
num_cores_per_instance: u16,
|
||||
) -> Result<Self, LlamaCppBackendError> {
|
||||
let shared_path = Arc::new(model_path);
|
||||
@ -110,7 +105,7 @@ impl LlamaCppBackend {
|
||||
0 => {
|
||||
let worker = Self::allocate_worker(path)?;
|
||||
let (sender, receiver) = channel();
|
||||
let handle = spawn(|| scheduler_loop(worker, tokenizer, receiver));
|
||||
let handle = spawn(move || scheduler_loop(worker, tokenizer, receiver));
|
||||
LlamaCppBackend::Single(LlamaCppWorker { sender, handle })
|
||||
}
|
||||
_ => panic!("No supported yet"),
|
||||
@ -186,7 +181,7 @@ fn llama_generate_callback(
|
||||
|
||||
fn scheduler_loop(
|
||||
mut backend: UniquePtr<LlamaCppWorkerFrontend>,
|
||||
tokenizer: Tokenizer,
|
||||
tokenizer: Arc<Tokenizer>,
|
||||
backlog: Receiver<(GenerationContext, UnboundedSender<InferResult>)>,
|
||||
) {
|
||||
// This loop will mostly decode single token at every step, so no need to rely on parallelism
|
||||
@ -195,17 +190,15 @@ fn scheduler_loop(
|
||||
loop {
|
||||
if let Ok((generation, stream)) = backlog.recv() {
|
||||
let start = Instant::now();
|
||||
let tokenizer = tokenizer.clone();
|
||||
let generation_params = generation.generation_params; // copy
|
||||
let sampling_params = generation.sampling_params; // copy
|
||||
let input_tokens = Arc::clone(&generation.input_tokens);
|
||||
|
||||
// Creating the whole InferContext and pushing it to the heap
|
||||
{
|
||||
let ctx = Box::new(InferContext {
|
||||
start,
|
||||
stream,
|
||||
tokenizer,
|
||||
tokenizer: &tokenizer,
|
||||
generation,
|
||||
});
|
||||
|
||||
@ -226,7 +219,6 @@ fn scheduler_loop(
|
||||
// Make sure we re-keep track of the OpaqueStream box
|
||||
let _ = Box::from_raw(boxed_ctx);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
info!("IPC channel is closed, exiting the scheduler loop");
|
||||
break;
|
||||
|
@ -33,7 +33,7 @@ mod ffi {
|
||||
}
|
||||
|
||||
extern "Rust" {
|
||||
type InferContext;
|
||||
type InferContext<'a>;
|
||||
}
|
||||
|
||||
unsafe extern "C++" {
|
||||
|
@ -1,5 +1,6 @@
|
||||
use clap::{Parser, Subcommand};
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use text_generation_backend_llamacpp::backend::{LlamaCppBackend, LlamaCppBackendError};
|
||||
use text_generation_router::server::ApiDoc;
|
||||
use text_generation_router::{server, usage_stats};
|
||||
@ -162,8 +163,10 @@ async fn main() -> Result<(), RouterError> {
|
||||
user_agent: Default::default(),
|
||||
auth_token,
|
||||
};
|
||||
let tokenizer = tokenizers::Tokenizer::from_pretrained(tokenizer_name.clone(), Some(options))
|
||||
.expect("Failed to retrieve tokenizer");
|
||||
let tokenizer = Arc::new(
|
||||
tokenizers::Tokenizer::from_pretrained(tokenizer_name.clone(), Some(options))
|
||||
.expect("Failed to retrieve tokenizer"),
|
||||
);
|
||||
let backend = LlamaCppBackend::new(gguf_path, tokenizer, num_cores_per_instance.unwrap_or(0))?;
|
||||
|
||||
// Run server
|
||||
|
Loading…
Reference in New Issue
Block a user